[libcxx] [compiler-rt] [llvm] [libc] [lldb] [lld] [flang] [mlir] [clang] [mlir][sparse] Implement parsing n out of m (PR #79935)
Yinying Li via cfe-commits
cfe-commits at lists.llvm.org
Tue Feb 6 10:22:51 PST 2024
https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/79935
>From b4610de041d1fd9c362a4155ee50325c738eebda Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 01:01:52 +0000
Subject: [PATCH 01/13] [mlir][sparse] Expand LevelType to 64 bit and implement
n out of m
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 28 +--
.../mlir/Dialect/SparseTensor/IR/Enums.h | 209 +++++++++++-------
.../SparseTensor/IR/SparseTensorAttrDefs.td | 4 +-
.../SparseTensor/IR/SparseTensorType.h | 2 +-
.../mlir/Dialect/SparseTensor/Utils/Merger.h | 2 +-
.../ExecutionEngine/SparseTensor/Storage.h | 14 +-
.../Bindings/Python/DialectSparseTensor.cpp | 2 +-
mlir/lib/CAPI/Dialect/SparseTensor.cpp | 49 ++--
.../IR/Detail/DimLvlMapParser.cpp | 2 +
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 55 ++++-
.../SparseTensor/IR/Detail/LvlTypeParser.h | 6 +-
.../Transforms/SparseGPUCodegen.cpp | 2 +-
.../Transforms/SparseTensorCodegen.cpp | 6 +-
.../Transforms/Sparsification.cpp | 2 +-
.../Transforms/Utils/SparseTensorLevel.cpp | 2 +-
.../lib/Dialect/SparseTensor/Utils/Merger.cpp | 4 +-
.../ExecutionEngine/SparseTensor/Storage.cpp | 2 +-
mlir/test/CAPI/sparse_tensor.c | 6 +-
.../SparseTensor/GPU/gpu_matmul24_lib.mlir | 2 +-
.../SparseTensor/roundtrip_encoding.mlir | 12 +-
.../SparseTensor/sparse_fill_zero.mlir | 2 +-
.../SparseTensor/CPU/sparse_block_matmul.mlir | 2 +-
.../Dialect/SparseTensor/CPU/sparse_ds.mlir | 2 +-
.../CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir | 2 +-
.../CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir | 2 +-
.../python/dialects/sparse_tensor/dialect.py | 106 ++++-----
26 files changed, 316 insertions(+), 211 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 42d8400cb5e95..947a746b60a65 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -28,20 +28,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
typedef uint64_t MlirSparseTensorLevelType;
enum MlirBaseSparseTensorLevelType {
- MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10, // 0b00010_10
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16, // 0b00100_00
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17, // 0b00100_01
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18, // 0b00100_10
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
- MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00
+ MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x10000, // 0x00_00_0001_0000
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x20000, // 0x00_00_0002_0000
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x20001, // 0x00_00_0002_0001
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x20002, // 0x00_00_0002_0002
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x20003, // 0x00_00_0002_0003
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x40000, // 0x00_00_0004_0000
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x40001, // 0x00_00_0004_0001
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x40002, // 0x00_00_0004_0002
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x40003, // 0x00_00_0004_0003
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x80000, // 0x00_00_0008_0000
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x80001, // 0x00_00_0008_0001
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x80002, // 0x00_00_0008_0002
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x80003, // 0x00_00_0008_0003
+ MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x100000, // 0x00_00_0010_0000
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1f662e2042304..b70ac57dfd00a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,9 +154,10 @@ enum class Action : uint32_t {
/// This enum defines all the sparse representations supportable by
/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton, loose_compressed,
-/// two-out-of-four) as well as the "properties" (ordered, unique). The
-/// encoding is chosen for performance of the runtime library, and thus may
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m for
+/// the NOutOfM level type.
+/// The encoding is chosen for performance of the runtime library, and thus may
/// change in future versions; consequently, client code should use the
/// predicate functions defined below, rather than relying on knowledge
/// about the particular binary encoding.
@@ -165,41 +166,74 @@ enum class Action : uint32_t {
/// where we need to store an undefined or indeterminate `LevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
+///
+/// Bit layout of LevelType, most significant first (bits 48-63 are unused):
+///
+/// | 8-bit m | 8-bit n | 16-bit LevelFormat | 16-bit LevelProperty |
+///
enum class LevelType : uint64_t {
- Undef = 0, // 0b00000_00
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- CompressedNu = 9, // 0b00010_01
- CompressedNo = 10, // 0b00010_10
- CompressedNuNo = 11, // 0b00010_11
- Singleton = 16, // 0b00100_00
- SingletonNu = 17, // 0b00100_01
- SingletonNo = 18, // 0b00100_10
- SingletonNuNo = 19, // 0b00100_11
- LooseCompressed = 32, // 0b01000_00
- LooseCompressedNu = 33, // 0b01000_01
- LooseCompressedNo = 34, // 0b01000_10
- LooseCompressedNuNo = 35, // 0b01000_11
- TwoOutOfFour = 64, // 0b10000_00
+ Undef = 0, // 0x00_00_0000_0000
+ Dense = 0x10000, // 0x00_00_0001_0000
+ Compressed = 0x20000, // 0x00_00_0002_0000
+ CompressedNu = 0x20001, // 0x00_00_0002_0001
+ CompressedNo = 0x20002, // 0x00_00_0002_0002
+ CompressedNuNo = 0x20003, // 0x00_00_0002_0003
+ Singleton = 0x40000, // 0x00_00_0004_0000
+ SingletonNu = 0x40001, // 0x00_00_0004_0001
+ SingletonNo = 0x40002, // 0x00_00_0004_0002
+ SingletonNuNo = 0x40003, // 0x00_00_0004_0003
+ LooseCompressed = 0x80000, // 0x00_00_0008_0000
+ LooseCompressedNu = 0x80001, // 0x00_00_0008_0001
+ LooseCompressedNo = 0x80002, // 0x00_00_0008_0002
+ LooseCompressedNuNo = 0x80003, // 0x00_00_0008_0003
+ NOutOfM = 0x100000, // 0x00_00_0010_0000
};
/// This enum defines all supported storage format without the level properties.
enum class LevelFormat : uint64_t {
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- Singleton = 16, // 0b00100_00
- LooseCompressed = 32, // 0b01000_00
- TwoOutOfFour = 64, // 0b10000_00
+ Dense = 0x10000, // 0x0001_0000
+ Compressed = 0x20000, // 0x0002_0000
+ Singleton = 0x40000, // 0x0004_0000
+ LooseCompressed = 0x80000, // 0x0008_0000
+ NOutOfM = 0x100000, // 0x0010_0000
};
/// This enum defines all the nondefault properties for storage formats.
enum class LevelPropertyNondefault : uint64_t {
- Nonunique = 1, // 0b00000_01
- Nonordered = 2, // 0b00000_10
+ Nonunique = 1, // 0x0001
+ Nonordered = 2, // 0x0002
};
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+}
+
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
+
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+ return ((static_cast<uint64_t>(lt) & 0xffff0000) ==
+ static_cast<uint64_t>(LevelType::NOutOfM));
+}
+
+/// Check if the `LevelType` is NOutOfM with the correct block sizes.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+ return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
+
/// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lt) {
+inline std::string toMLIRString(LevelType lt) {
switch (lt) {
case LevelType::Undef:
return "undef";
@@ -229,21 +263,32 @@ constexpr const char *toMLIRString(LevelType lt) {
return "loose_compressed(nonordered)";
case LevelType::LooseCompressedNuNo:
return "loose_compressed(nonunique, nonordered)";
- case LevelType::TwoOutOfFour:
- return "block2_4";
+ default:
+ // If NOutOfM bit is set, print the [n, m] sizes.
+ if (isNOutOfMLT(lt)) {
+ const uint64_t n = getN(lt);
+ const uint64_t m = getM(lt);
+ return std::string("block[") + std::to_string(n) + ", " +
+ std::to_string(m) + "]";
+ }
}
return "";
}
/// Check that the `LevelType` contains a valid (possibly undefined) value.
constexpr bool isValidLT(LevelType lt) {
- const uint64_t formatBits = static_cast<uint64_t>(lt) >> 2;
- const uint64_t propertyBits = static_cast<uint64_t>(lt) & 3;
- // If undefined or dense, then must be unique and ordered.
+ const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+ const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+ const uint64_t sizeBits = static_cast<uint64_t>(lt) >> 32;
+ // Only NOutOfM may carry n and m (bits 32-47); bits 48-63 must be zero.
+ if (sizeBits > 0xffff || (sizeBits != 0 && formatBits != 0x100000))
+ return false;
+ // If undefined/dense/NOutOfM, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
- return (formatBits <= 1 || formatBits == 16)
+ return (formatBits <= 0x10000 || formatBits == 0x100000)
? (propertyBits == 0)
- : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+ : (formatBits == 0x20000 || formatBits == 0x40000 ||
+ formatBits == 0x80000);
}
/// Check if the `LevelType` is the special undefined value.
@@ -251,33 +296,28 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
/// Check if the `LevelType` is dense (regardless of properties).
constexpr bool isDenseLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
static_cast<uint64_t>(LevelType::Dense);
}
/// Check if the `LevelType` is compressed (regardless of properties).
constexpr bool isCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
static_cast<uint64_t>(LevelType::Compressed);
}
/// Check if the `LevelType` is singleton (regardless of properties).
constexpr bool isSingletonLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
static_cast<uint64_t>(LevelType::Singleton);
}
/// Check if the `LevelType` is loose compressed (regardless of properties).
constexpr bool isLooseCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
static_cast<uint64_t>(LevelType::LooseCompressed);
}
-/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
-constexpr bool is2OutOf4LT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
- static_cast<uint64_t>(LevelType::TwoOutOfFour);
-}
/// Check if the `LevelType` needs positions array.
constexpr bool isWithPosLT(LevelType lt) {
@@ -287,17 +327,17 @@ constexpr bool isWithPosLT(LevelType lt) {
/// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT(LevelType lt) {
return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt);
+ isNOutOfMLT(lt);
}
/// Check if the `LevelType` is ordered (regardless of storage format).
constexpr bool isOrderedLT(LevelType lt) {
return !(static_cast<uint64_t>(lt) & 2);
}
/// Check if the `LevelType` is unique (regardless of storage format).
constexpr bool isUniqueLT(LevelType lt) {
return !(static_cast<uint64_t>(lt) & 1);
}
/// Convert a LevelType to its corresponding LevelFormat.
@@ -305,21 +345,25 @@ constexpr bool isUniqueLT(LevelType lt) {
constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
if (lt == LevelType::Undef)
return std::nullopt;
- return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & ~3);
+ return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
}
/// Convert a LevelFormat to its corresponding LevelType with the given
/// properties. Returns std::nullopt when the properties are not applicable
/// for the input level format.
constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
- bool unique) {
- auto lt = static_cast<LevelType>(static_cast<uint64_t>(lf) |
- (ordered ? 0 : 2) | (unique ? 0 : 1));
+ bool unique, uint64_t n = 0,
+ uint64_t m = 0) {
+ if (n > 0xff || m > 0xff) // n and m each have only an 8-bit field.
+ return std::nullopt;
+ auto lt = static_cast<LevelType>(static_cast<uint64_t>(lf) |
+ (ordered ? 0 : 2) | (unique ? 0 : 1) |
+ nToBits(n) | mToBits(m));
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
}
//
-// Ensure the above methods work as indended.
+// Ensure the above methods work as intended.
//
static_assert(
@@ -341,7 +385,7 @@ static_assert(
LevelFormat::LooseCompressed &&
*getLevelFormat(LevelType::LooseCompressedNuNo) ==
LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
+ *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
"getLevelFormat conversion is broken");
static_assert(
@@ -373,13 +417,30 @@ static_assert(
LevelType::LooseCompressedNo &&
*buildLevelType(LevelFormat::LooseCompressed, false, false) ==
LevelType::LooseCompressedNuNo &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
- LevelType::TwoOutOfFour),
+ buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
+ buildLevelType(LevelFormat::Dense, true, true, 2, 4) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, true, true, 256, 4) == std::nullopt &&
+ *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
"buildLevelType conversion is broken");
+static_assert(
+ (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
+ getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
+ "getN/M conversion is broken");
+
+static_assert(
+ (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
+ 2, 4) &&
+ isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
+ 8, 10) &&
+ !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
+ 2, 4)),
+ "isValidNOutOfMLT definition is broken");
+
static_assert(
(isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
@@ -391,7 +452,7 @@ static_assert(
isValidLT(LevelType::LooseCompressedNu) &&
isValidLT(LevelType::LooseCompressedNo) &&
isValidLT(LevelType::LooseCompressedNuNo) &&
- isValidLT(LevelType::TwoOutOfFour)),
+ isValidLT(LevelType::NOutOfM)),
"isValidLT definition is broken");
static_assert((isDenseLT(LevelType::Dense) &&
@@ -407,7 +468,7 @@ static_assert((isDenseLT(LevelType::Dense) &&
!isDenseLT(LevelType::LooseCompressedNu) &&
!isDenseLT(LevelType::LooseCompressedNo) &&
!isDenseLT(LevelType::LooseCompressedNuNo) &&
- !isDenseLT(LevelType::TwoOutOfFour)),
+ !isDenseLT(LevelType::NOutOfM)),
"isDenseLT definition is broken");
static_assert((!isCompressedLT(LevelType::Dense) &&
@@ -423,7 +484,7 @@ static_assert((!isCompressedLT(LevelType::Dense) &&
!isCompressedLT(LevelType::LooseCompressedNu) &&
!isCompressedLT(LevelType::LooseCompressedNo) &&
!isCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isCompressedLT(LevelType::TwoOutOfFour)),
+ !isCompressedLT(LevelType::NOutOfM)),
"isCompressedLT definition is broken");
static_assert((!isSingletonLT(LevelType::Dense) &&
@@ -439,7 +500,7 @@ static_assert((!isSingletonLT(LevelType::Dense) &&
!isSingletonLT(LevelType::LooseCompressedNu) &&
!isSingletonLT(LevelType::LooseCompressedNo) &&
!isSingletonLT(LevelType::LooseCompressedNuNo) &&
- !isSingletonLT(LevelType::TwoOutOfFour)),
+ !isSingletonLT(LevelType::NOutOfM)),
"isSingletonLT definition is broken");
static_assert((!isLooseCompressedLT(LevelType::Dense) &&
@@ -455,24 +516,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) &&
isLooseCompressedLT(LevelType::LooseCompressedNu) &&
isLooseCompressedLT(LevelType::LooseCompressedNo) &&
isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isLooseCompressedLT(LevelType::TwoOutOfFour)),
+ !isLooseCompressedLT(LevelType::NOutOfM)),
"isLooseCompressedLT definition is broken");
-static_assert((!is2OutOf4LT(LevelType::Dense) &&
- !is2OutOf4LT(LevelType::Compressed) &&
- !is2OutOf4LT(LevelType::CompressedNu) &&
- !is2OutOf4LT(LevelType::CompressedNo) &&
- !is2OutOf4LT(LevelType::CompressedNuNo) &&
- !is2OutOf4LT(LevelType::Singleton) &&
- !is2OutOf4LT(LevelType::SingletonNu) &&
- !is2OutOf4LT(LevelType::SingletonNo) &&
- !is2OutOf4LT(LevelType::SingletonNuNo) &&
- !is2OutOf4LT(LevelType::LooseCompressed) &&
- !is2OutOf4LT(LevelType::LooseCompressedNu) &&
- !is2OutOf4LT(LevelType::LooseCompressedNo) &&
- !is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
- is2OutOf4LT(LevelType::TwoOutOfFour)),
- "is2OutOf4LT definition is broken");
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+ !isNOutOfMLT(LevelType::Compressed) &&
+ !isNOutOfMLT(LevelType::CompressedNu) &&
+ !isNOutOfMLT(LevelType::CompressedNo) &&
+ !isNOutOfMLT(LevelType::CompressedNuNo) &&
+ !isNOutOfMLT(LevelType::Singleton) &&
+ !isNOutOfMLT(LevelType::SingletonNu) &&
+ !isNOutOfMLT(LevelType::SingletonNo) &&
+ !isNOutOfMLT(LevelType::SingletonNuNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressed) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+ isNOutOfMLT(LevelType::NOutOfM)),
+ "isNOutOfMLT definition is broken");
static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::Compressed) &&
@@ -487,7 +548,7 @@ static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::LooseCompressedNu) &&
!isOrderedLT(LevelType::LooseCompressedNo) &&
!isOrderedLT(LevelType::LooseCompressedNuNo) &&
- isOrderedLT(LevelType::TwoOutOfFour)),
+ isOrderedLT(LevelType::NOutOfM)),
"isOrderedLT definition is broken");
static_assert((isUniqueLT(LevelType::Dense) &&
@@ -503,7 +564,7 @@ static_assert((isUniqueLT(LevelType::Dense) &&
!isUniqueLT(LevelType::LooseCompressedNu) &&
isUniqueLT(LevelType::LooseCompressedNo) &&
!isUniqueLT(LevelType::LooseCompressedNuNo) &&
- isUniqueLT(LevelType::TwoOutOfFour)),
+ isUniqueLT(LevelType::NOutOfM)),
"isUniqueLT definition is broken");
/// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 12c1068ae1f54..299ba0e603089 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
- **compressed** : only nonzeros along this level are stored
- **loose_compressed** : as compressed, but allows for free space between regions
- **singleton** : a variant of the compressed format, where coordinates have no siblings
- - **block2_4** : the compression uses a 2:4 encoding per 1x4 block
+ - **block[n, m]** : the compression uses an n:m encoding per 1xm block (e.g., **block[2, 4]** for 2:4)
For a compressed level, each position interval is represented in a compact
way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
@@ -374,7 +374,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedLT(getLvlType(l)); }
bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonLT(getLvlType(l)); }
bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedLT(getLvlType(l)); }
- bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4LT(getLvlType(l)); }
+ bool isNOutOfMLvl(::mlir::sparse_tensor::Level l) const { return isNOutOfMLT(getLvlType(l)); }
bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedLT(getLvlType(l)); }
bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueLT(getLvlType(l)); }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4c98129744bcd..4e2b85d35c1ac 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -291,7 +291,7 @@ class SparseTensorType {
return isLooseCompressedLT(getLvlType(l));
}
bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); }
- bool is2OutOf4Lvl(Level l) const { return is2OutOf4LT(getLvlType(l)); }
+ bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); }
bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); }
bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); }
bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4a34bb2e003e8..490ef3071af1b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -510,7 +510,7 @@ class Merger {
if (isLvlWithNonTrivialIdxExp(b)) {
auto lt = getLoopDependentLevelType(b);
return isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || is2OutOf4LT(lt);
+ isLooseCompressedLT(lt) || isNOutOfMLT(lt);
}
return false;
}
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 01c5f2382ffe6..1d8d9bcfb3b2c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -124,7 +124,7 @@ class SparseTensorStorageBase {
bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }
-/// Safely checks if the level uses 2 out of 4 storage.
+/// Safely checks if the level uses n out of m storage.
- bool is2OutOf4Lvl(uint64_t l) const { return is2OutOf4LT(getLvlType(l)); }
+ bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }
/// Safely checks if the level is ordered.
bool isOrderedLvl(uint64_t l) const { return isOrderedLT(getLvlType(l)); }
@@ -450,7 +450,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) {
if (!isDenseLvl(lvl)) {
assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) ||
- isSingletonLvl(lvl) || is2OutOf4Lvl(lvl));
+ isSingletonLvl(lvl) || isNOutOfMLvl(lvl));
coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
} else { // Dense level.
assert(crd >= full && "Coordinate was already filled");
@@ -473,7 +473,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
return positions[l][parentSz];
if (isLooseCompressedLvl(l))
return positions[l][2 * parentSz - 1];
- if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+ if (isSingletonLvl(l) || isNOutOfMLvl(l))
return parentSz; // new size same as the parent
assert(isDenseLvl(l));
return parentSz * getLvlSize(l);
@@ -527,7 +527,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
uint64_t pos = coordinates[l].size();
positions[l].insert(positions[l].end(), 2 * count,
detail::checkOverflowCast<P>(pos));
- } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+ } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
return; // Nothing to finalize.
} else { // Dense dimension.
assert(isDenseLvl(l));
@@ -624,7 +624,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
toCOO(pos, l + 1, dimCoords);
}
- } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+ } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
assert(parentPos < coordinates[l].size());
lvlCursor[l] = static_cast<uint64_t>(coordinates[l][parentPos]);
toCOO(parentPos, l + 1, dimCoords);
@@ -721,7 +721,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
} else if (isSingletonLvl(l)) {
coordinates[l].reserve(sz);
sz = 1;
- } else if (is2OutOf4Lvl(l)) {
+ } else if (isNOutOfMLvl(l)) {
assert(l == lvlRank - 1 && "unexpected 2:4 usage");
- sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
+ sz = detail::checkedMul(sz, lvlSizes[l]) / 2; // TODO: assumes 2:4; use n and m.
coordinates[l].reserve(sz);
@@ -791,7 +791,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
}
} else if (isSingletonLvl(l)) {
assert(0 && "general singleton not supported yet");
- } else if (is2OutOf4Lvl(l)) {
+ } else if (isNOutOfMLvl(l)) {
- assert(0 && "2Out4 not supported yet");
+ assert(0 && "n out of m not supported yet");
} else {
assert(isDenseLvl(l));
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 698367a1addaf..607534c615643 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -25,7 +25,7 @@ using namespace mlir::python::adaptors;
static void populateDialectSparseTensorSubmodule(const py::module &m) {
py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
- .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
+ .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
.value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
.value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
.value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e4534ad132385..a34b9a29b0e90 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -20,25 +20,36 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
mlir::sparse_tensor::SparseTensorDialect)
// Ensure the C-API enums are int-castable to C++ equivalents.
-static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
- static_cast<int>(LevelType::Dense) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
- static_cast<int>(LevelType::Compressed) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
- static_cast<int>(LevelType::CompressedNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
- static_cast<int>(LevelType::CompressedNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
- static_cast<int>(LevelType::CompressedNuNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
- static_cast<int>(LevelType::Singleton) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
- static_cast<int>(LevelType::SingletonNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
- static_cast<int>(LevelType::SingletonNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
- static_cast<int>(LevelType::SingletonNuNo),
- "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
+static_assert(
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
+ static_cast<uint64_t>(LevelType::Dense) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
+ static_cast<uint64_t>(LevelType::Compressed) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
+ static_cast<uint64_t>(LevelType::CompressedNu) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
+ static_cast<uint64_t>(LevelType::CompressedNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
+ static_cast<uint64_t>(LevelType::CompressedNuNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
+ static_cast<uint64_t>(LevelType::Singleton) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
+ static_cast<uint64_t>(LevelType::SingletonNu) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
+ static_cast<uint64_t>(LevelType::SingletonNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
+ static_cast<uint64_t>(LevelType::SingletonNuNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
+ static_cast<uint64_t>(LevelType::LooseCompressed) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) ==
+ static_cast<uint64_t>(LevelType::LooseCompressedNu) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) ==
+ static_cast<uint64_t>(LevelType::LooseCompressedNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) ==
+ static_cast<uint64_t>(LevelType::LooseCompressedNuNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
+ static_cast<uint64_t>(LevelType::NOutOfM),
+ "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
return isa<SparseTensorEncodingAttr>(unwrap(attr));
index 56b435c57d30a..95874d4857fc8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -299,6 +299,8 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
FAILURE_IF_FAILED(type)
lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
+ llvm::errs() << "type = " << toMLIRString(static_cast<LevelType>(*type))
+ << "\n";
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index eb7ea63a4e88b..14ebe14b49f64 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -29,12 +29,21 @@ using namespace mlir::sparse_tensor::ir_detail;
// `LvlTypeParser` implementation.
//===----------------------------------------------------------------------===//
-FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
StringRef base;
const auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
"expected valid level format (e.g. dense, compressed or singleton)")
- uint8_t properties = 0;
+ uint64_t properties = 0;
+ SmallVector<unsigned> blockSizes;
+
+ if (base.compare("block") == 0) {
+ ParseResult res = parser.parseCommaSeparatedList(
+ mlir::OpAsmParser::Delimiter::OptionalSquare,
+ [&]() -> ParseResult { return parseBlockSize(parser, &blockSizes); },
+ " in block n out of m");
+ FAILURE_IF_FAILED(res)
+ }
ParseResult res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -44,15 +53,21 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
// Set the base bit for properties.
if (base.compare("dense") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Dense);
+ properties |= static_cast<uint64_t>(LevelFormat::Dense);
} else if (base.compare("compressed") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Compressed);
- } else if (base.compare("block2_4") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::TwoOutOfFour);
+ properties |= static_cast<uint64_t>(LevelFormat::Compressed);
+ } else if (base.compare("block") == 0) {
+ if (blockSizes.size() != 2) {
+ parser.emitError(loc, "expected exactly 2 block sizes");
+ return failure();
+ }
+ properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
+ properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]);
+ llvm::errs() << "properties1: " << properties << "\n";
} else if (base.compare("loose_compressed") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::LooseCompressed);
+ properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
} else if (base.compare("singleton") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Singleton);
+ properties |= static_cast<uint64_t>(LevelFormat::Singleton);
} else {
parser.emitError(loc, "unknown level format: ") << base;
return failure();
@@ -64,15 +79,15 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
}
ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
- uint8_t *properties) const {
+ uint64_t *properties) const {
StringRef strVal;
auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
"expected valid level property (e.g. nonordered, nonunique or high)")
if (strVal.compare("nonunique") == 0) {
- *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonunique);
+ *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
} else if (strVal.compare("nonordered") == 0) {
- *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonordered);
+ *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
} else {
parser.emitError(loc, "unknown level property: ") << strVal;
return failure();
@@ -80,4 +95,22 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
return success();
}
+ParseResult
+LvlTypeParser::parseBlockSize(AsmParser &parser,
+ SmallVector<unsigned> *blockSizes) const {
+ int intVal;
+ auto loc = parser.getCurrentLocation();
+ OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
+ if (intValParseResult.has_value()) {
+ if (failed(*intValParseResult)) {
+ parser.emitError(loc, "failed to parse block size");
+ return failure();
+ }
+ blockSizes->push_back(intVal);
+ return success();
+ }
+ parser.emitError(loc, "expected valid integer for block size");
+ return failure();
+}
+
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 5e2f11b75d4da..78ae667f97923 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -18,10 +18,12 @@ namespace ir_detail {
class LvlTypeParser {
public:
LvlTypeParser() = default;
- FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
+ FailureOr<uint64_t> parseLvlType(AsmParser &parser) const;
private:
- ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
+ ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
+ ParseResult parseBlockSize(AsmParser &parser,
+ SmallVector<unsigned> *blockSizes) const;
};
} // namespace ir_detail
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 87a37a7926e9e..23676eccdfb28 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -451,7 +451,7 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
/// Test for 2:4 matrix with suitable metadata.
static bool isAdmissible24(SparseTensorType &aTp) {
return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
- aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+ aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
}
/// Test for conversion into 2:4 matrix.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 491501a3381b9..d4459c6ea1e52 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -130,7 +130,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
/*value=*/posZero, /*repeat=*/linear);
return;
- } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+ } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
return; // nothing to do
}
// Keep compounding the size, but nothing needs to be initialized
@@ -409,7 +409,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
}
} else {
assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
- is2OutOf4LT(lt));
+ isNOutOfMLT(lt));
}
}
}
@@ -488,7 +488,7 @@ class SparseInsertGenerator
}
parentPos =
genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
- } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+ } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
// Create:
// coordinates[lvl].push_back(coords[lvl])
// positions[lvl] = positions[lvl-1]
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index ab38ab5cc3f78..8f2ae60b311f7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -891,7 +891,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
assert(curr == env.merger().loop(b));
Value clause;
if (isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
+ isLooseCompressedLT(lt) || isNOutOfMLT(lt)) {
assert(lvl.has_value());
const Value crd = env.emitter().getCoord(tid, *lvl);
const Value lvar = env.getLoopVar(curr);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 04b49c320f07a..320f87ce780ac 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1294,7 +1294,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value crd = genToCoordinates(b, l, t, lvl);
return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
}
- case LevelFormat::TwoOutOfFour: {
+ case LevelFormat::NOutOfM: {
Value crd = genToCoordinates(b, l, t, lvl);
return std::make_unique<TwoOutFourLevel>(tid, lvl, lt, sz, crd);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 6cdf5f8c0168b..96537cbb0c483 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -489,7 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto lt = getLvlType(b);
if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
- !isLooseCompressedLT(lt) && !is2OutOf4LT(lt)) {
+ !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
if (reset)
simple.reset(b);
reset = true;
@@ -670,7 +670,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
const auto lt = getLvlType(b);
if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt))
+ isNOutOfMLT(lt))
return true;
}
return hasSparseIdxReduction(bits);
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 0c7b3a228a65c..9e8b240899d80 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -45,7 +45,7 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
for (uint64_t l = 0; l < lvlRank; l++) {
assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
- isSingletonLvl(l) || is2OutOf4Lvl(l));
+ isSingletonLvl(l) || isNOutOfMLvl(l));
}
}
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 4f1d397517548..112a11a43dec6 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -37,9 +37,9 @@ static int testRoundtripEncoding(MlirContext ctx) {
mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr);
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(dimToLvl);
- // CHECK: level_type: 4
- // CHECK: level_type: 8
- // CHECK: level_type: 8
+ // CHECK: level_type: 65536
+ // CHECK: level_type: 131072
+ // CHECK: level_type: 131072
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index 6fe7ec906f30e..e4884a9bf393f 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -4,7 +4,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
)
}>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 20702bb985028..966a7ff2d38e1 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -207,12 +207,12 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
),
crdWidth = 8 // we would even like just 2-bits
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block2_4), crdWidth = 8 }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block[2, 4]), crdWidth = 8 }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
@@ -226,11 +226,11 @@ func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
( i : dense,
j : dense,
k floordiv 4 : dense,
- k mod 4 : block2_4
+ k mod 4 : block[2, 4]
)
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block[2, 4]) }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
@@ -244,11 +244,11 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
( i : dense,
k floordiv 4 : dense,
j : dense,
- k mod 4 : block2_4
+ k mod 4 : block[2, 4]
)
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block[2, 4]) }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 7c494b2bcfe1d..d04fbe8ed5c22 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,7 +14,7 @@
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i64
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 131072 : i64
// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 4bc080fc538fc..554d6207aef7e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -59,7 +59,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
),
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
index df5b48a3b6ece..9935d7c69e63a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
@@ -41,7 +41,7 @@
#NV_24 = #sparse_tensor.encoding<{
map = ( i, j ) -> ( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4),
+ j mod 4 : block[2, 4]),
crdWidth = 8
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index 17b50b46d073a..25454f5c06b45 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -20,7 +20,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
)
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index eb99a027a8860..da735b4a3b58a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -20,7 +20,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
)
}>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 946a224dab064..bdad57b5066cc 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -13,22 +13,22 @@ def run(f):
# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
- with Context() as ctx:
- parsed = Attribute.parse(
- "#sparse_tensor.encoding<{"
- " map = (d0) -> (d0 : compressed),"
- " posWidth = 16,"
- " crdWidth = 32"
- "}>"
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
-
- # CHECK: lvl_types: [8]
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ " map = (d0) -> (d0 : compressed),"
+ " posWidth = 16,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
+
+ # CHECK: lvl_types: [131072]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0) -> (d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -39,38 +39,38 @@ def testEncodingAttr1D():
# CHECK: crd_width: 32
print(f"crd_width: {casted.crd_width}")
- created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
- print(created)
- # CHECK: created_equal: False
- print(f"created_equal: {created == casted}")
+ created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+ print(created)
+ # CHECK: created_equal: False
+ print(f"created_equal: {created == casted}")
- # Verify that the factory creates an instance of the proper type.
- # CHECK: is_proper_instance: True
- print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
- # CHECK: created_pos_width: 0
- print(f"created_pos_width: {created.pos_width}")
+ # Verify that the factory creates an instance of the proper type.
+ # CHECK: is_proper_instance: True
+ print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+ # CHECK: created_pos_width: 0
+ print(f"created_pos_width: {created.pos_width}")
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
- with Context() as ctx:
- parsed = Attribute.parse(
- "#sparse_tensor.encoding<{"
- " map = (d0, d1) -> (d1 : dense, d0 : compressed),"
- " posWidth = 8,"
- " crdWidth = 32"
- "}>"
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
-
- # CHECK: lvl_types: [4, 8]
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ " map = (d0, d1) -> (d1 : dense, d0 : compressed),"
+ " posWidth = 8,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
+
+ # CHECK: lvl_types: [65536, 131072]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -81,17 +81,17 @@ def testEncodingAttr2D():
# CHECK: crd_width: 32
print(f"crd_width: {casted.crd_width}")
- created = st.EncodingAttr.get(
- casted.lvl_types,
- casted.dim_to_lvl,
- casted.lvl_to_dim,
- 8,
- 32,
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
- print(created)
- # CHECK: created_equal: True
- print(f"created_equal: {created == casted}")
+ created = st.EncodingAttr.get(
+ casted.lvl_types,
+ casted.dim_to_lvl,
+ casted.lvl_to_dim,
+ 8,
+ 32,
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+ print(created)
+ # CHECK: created_equal: True
+ print(f"created_equal: {created == casted}")
# CHECK-LABEL: TEST: testEncodingAttrOnTensorType
>From e96819adce665d7dad77b1270b132e8804c4930f Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 02:58:18 +0000
Subject: [PATCH 02/13] format
---
.../include/mlir/Dialect/SparseTensor/IR/Enums.h | 8 ++------
.../SparseTensor/IR/Detail/DimLvlMapParser.cpp | 2 --
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 1 -
.../SparseTensor/IR/SparseTensorDialect.cpp | 16 ++++++++++++++--
4 files changed, 16 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index b70ac57dfd00a..fe697bac5673e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -233,7 +233,7 @@ constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
}
/// Returns string representation of the given dimension level type.
-std::string toMLIRString(LevelType lt) {
+constexpr const char *toMLIRString(LevelType lt) {
switch (lt) {
case LevelType::Undef:
return "undef";
@@ -264,12 +264,8 @@ std::string toMLIRString(LevelType lt) {
case LevelType::LooseCompressedNuNo:
return "loose_compressed(nonunique, nonordered)";
default:
- // If NOutOfM bit is set, print the [n, m] sizes.
if (isNOutOfMLT(lt)) {
- unsigned n = getN(lt);
- unsigned m = getM(lt);
- return std::string("block[") + std::to_string(n) + ", " +
- std::to_string(m) + "]";
+ return "block";
}
}
return "";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 95874d4857fc8..56b435c57d30a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -299,8 +299,6 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
FAILURE_IF_FAILED(type)
lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
- llvm::errs() << "type = " << toMLIRString(static_cast<LevelType>(*type))
- << "\n";
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 14ebe14b49f64..993ad9be8a012 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -63,7 +63,6 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
}
properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]);
- llvm::errs() << "properties1: " << properties << "\n";
} else if (base.compare("loose_compressed") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
} else if (base.compare("singleton") == 0) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 27125bc7ed45e..67b1d7974fa25 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -613,16 +613,28 @@ void SparseTensorEncodingAttr::printDimensions(
}
}
+std::string getNOutOfMString(LevelType lt) {
+ if (isNOutOfMLT(lt)) {
+ unsigned n = getN(lt);
+ unsigned m = getM(lt);
+ auto output = "[" + std::to_string(n) + ", " + std::to_string(m) + "]";
+ return output;
+ }
+ return "";
+}
+
void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
ArrayRef<LevelType> lvlTypes) const {
for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
map.getResult(i).print(printer.getStream());
- printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
+ printer << " : " << toMLIRString(lvlTypes[i])
+ << getNOutOfMString(lvlTypes[i]) << ", ";
}
if (map.getNumResults() >= 1) {
auto lastIndex = map.getNumResults() - 1;
map.getResult(lastIndex).print(printer.getStream());
- printer << " : " << toMLIRString(lvlTypes[lastIndex]);
+ printer << " : " << toMLIRString(lvlTypes[lastIndex])
+ << getNOutOfMString(lvlTypes[lastIndex]);
}
}
>From a6dcfac959d8837733fafb32e427f78058b6b9ec Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 03:13:30 +0000
Subject: [PATCH 03/13] python format
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 1 -
.../python/dialects/sparse_tensor/dialect.py | 98 +++++++++----------
2 files changed, 49 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index fe697bac5673e..57e324412f540 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -310,7 +310,6 @@ constexpr bool isLooseCompressedLT(LevelType lt) {
static_cast<uint64_t>(LevelType::LooseCompressed);
}
-
/// Check if the `LevelType` needs positions array.
constexpr bool isWithPosLT(LevelType lt) {
return isCompressedLT(lt) || isLooseCompressedLT(lt);
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index bdad57b5066cc..412c5797067b7 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -13,20 +13,20 @@ def run(f):
# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
- with Context() as ctx:
- parsed = Attribute.parse(
- "#sparse_tensor.encoding<{"
- " map = (d0) -> (d0 : compressed),"
- " posWidth = 16,"
- " crdWidth = 32"
- "}>"
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ " map = (d0) -> (d0 : compressed),"
+ " posWidth = 16,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
# CHECK: lvl_types: [131072]
print(f"lvl_types: {casted.lvl_types}")
@@ -39,36 +39,36 @@ def testEncodingAttr1D():
# CHECK: crd_width: 32
print(f"crd_width: {casted.crd_width}")
- created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
- print(created)
- # CHECK: created_equal: False
- print(f"created_equal: {created == casted}")
+ created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+ print(created)
+ # CHECK: created_equal: False
+ print(f"created_equal: {created == casted}")
- # Verify that the factory creates an instance of the proper type.
- # CHECK: is_proper_instance: True
- print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
- # CHECK: created_pos_width: 0
- print(f"created_pos_width: {created.pos_width}")
+ # Verify that the factory creates an instance of the proper type.
+ # CHECK: is_proper_instance: True
+ print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+ # CHECK: created_pos_width: 0
+ print(f"created_pos_width: {created.pos_width}")
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
- with Context() as ctx:
- parsed = Attribute.parse(
- "#sparse_tensor.encoding<{"
- " map = (d0, d1) -> (d1 : dense, d0 : compressed),"
- " posWidth = 8,"
- " crdWidth = 32"
- "}>"
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ " map = (d0, d1) -> (d1 : dense, d0 : compressed),"
+ " posWidth = 8,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
# CHECK: lvl_types: [65536, 131072]
print(f"lvl_types: {casted.lvl_types}")
@@ -81,17 +81,17 @@ def testEncodingAttr2D():
# CHECK: crd_width: 32
print(f"crd_width: {casted.crd_width}")
- created = st.EncodingAttr.get(
- casted.lvl_types,
- casted.dim_to_lvl,
- casted.lvl_to_dim,
- 8,
- 32,
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
- print(created)
- # CHECK: created_equal: True
- print(f"created_equal: {created == casted}")
+ created = st.EncodingAttr.get(
+ casted.lvl_types,
+ casted.dim_to_lvl,
+ casted.lvl_to_dim,
+ 8,
+ 32,
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+ print(created)
+ # CHECK: created_equal: True
+ print(f"created_equal: {created == casted}")
# CHECK-LABEL: TEST: testEncodingAttrOnTensorType
>From 2ef4a749be636c28008a03c07f24c566e2e8d717 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 19:37:09 +0000
Subject: [PATCH 04/13] more edits for n:m
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 2 +-
.../ExecutionEngine/SparseTensor/Storage.h | 4 ++--
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 6 +++---
.../SparseTensor/IR/Detail/LvlTypeParser.h | 4 ++--
.../SparseTensor/roundtrip_encoding.mlir | 18 ++++++++++++++++++
5 files changed, 26 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 299ba0e603089..08ba96d437045 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
- **compressed** : only nonzeros along this level are stored
- **loose_compressed** : as compressed, but allows for free space between regions
- **singleton** : a variant of the compressed format, where coordinates have no siblings
- - **block[2, 4]** : the compression uses a 2:4 encoding per 1x4 block
+ - **block[n, m]** : the compression uses a n:m encoding per 1xm block
For a compressed level, each position interval is represented in a compact
way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 1d8d9bcfb3b2c..14809381517e7 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -123,7 +123,7 @@ class SparseTensorStorageBase {
/// Safely checks if the level uses singleton storage.
bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }
- /// Safely checks if the level uses 2 out of 4 storage.
+ /// Safely checks if the level uses n out of m storage.
bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }
/// Safely checks if the level is ordered.
@@ -792,7 +792,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
} else if (isSingletonLvl(l)) {
assert(0 && "general singleton not supported yet");
} else if (isNOutOfMLvl(l)) {
- assert(0 && "2Out4 not supported yet");
+ assert(0 && "n ouf of m not supported yet");
} else {
assert(isDenseLvl(l));
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 993ad9be8a012..022cc37615f13 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -40,7 +40,7 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
if (base.compare("block") == 0) {
ParseResult res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::OptionalSquare,
- [&]() -> ParseResult { return parseBlockSize(parser, &blockSizes); },
+ [&]() -> ParseResult { return parseBlockSizes(parser, &blockSizes); },
" in block n out of m");
FAILURE_IF_FAILED(res)
}
@@ -95,8 +95,8 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
}
ParseResult
-LvlTypeParser::parseBlockSize(AsmParser &parser,
- SmallVector<unsigned> *blockSizes) const {
+LvlTypeParser::parseBlockSizes(AsmParser &parser,
+ SmallVector<unsigned> *blockSizes) const {
int intVal;
auto loc = parser.getCurrentLocation();
OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 78ae667f97923..250c98fb8702b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -22,8 +22,8 @@ class LvlTypeParser {
private:
ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
- ParseResult parseBlockSize(AsmParser &parser,
- SmallVector<unsigned> *blockSizes) const;
+ ParseResult parseBlockSizes(AsmParser &parser,
+ SmallVector<unsigned> *blockSizes) const;
};
} // namespace ir_detail
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 966a7ff2d38e1..367b9c736095a 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -254,3 +254,21 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
return
}
+
+// -----
+
+#NOutOfM = #sparse_tensor.encoding<{
+ map = ( i, j, k ) ->
+ ( i : dense,
+ k floordiv 8 : dense,
+ j : dense,
+ k mod 8 : block[5, 8]
+ )
+}>
+
+// CHECK-DAG: #[[$NOutOfM:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 8 : dense, d1 : dense, d2 mod 8 : block[5, 8]) }>
+// CHECK-LABEL: func private @NOutOfM(
+// CHECK-SAME: tensor<?x?x?xf64, #[[$NOutOfM]]>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+ return
+}
>From d8ee03088957ac4525eaf4117fce1aa9de1e747c Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 23:08:20 +0000
Subject: [PATCH 05/13] address comments
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 28 +++++------
.../mlir/Dialect/SparseTensor/IR/Enums.h | 46 +++++++++----------
.../SparseTensor/IR/SparseTensorAttrDefs.td | 3 +-
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 20 ++++----
.../SparseTensor/IR/Detail/LvlTypeParser.h | 4 +-
.../SparseTensor/GPU/gpu_matmul24_lib.mlir | 2 +-
.../SparseTensor/roundtrip_encoding.mlir | 16 +++----
.../SparseTensor/CPU/sparse_block_matmul.mlir | 2 +-
.../Dialect/SparseTensor/CPU/sparse_ds.mlir | 2 +-
.../CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir | 2 +-
.../CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir | 2 +-
11 files changed, 64 insertions(+), 63 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 947a746b60a65..2c71b0008ad16 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -28,20 +28,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
typedef uint64_t MlirSparseTensorLevelType;
enum MlirBaseSparseTensorLevelType {
- MLIR_SPARSE_TENSOR_LEVEL_DENSE = 65536, // 0x00_00_0001_0000
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 131072, // 0x00_00_0002_0000
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 131073, // 0x00_00_0002_0001
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 131074, // 0x00_00_0002_0002
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 131075, // 0x00_00_0002_0003
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 262144, // 0x00_00_0004_0000
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 262145, // 0x00_00_0004_0001
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 262146, // 0x00_00_0004_0002
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 262147, // 0x00_00_0004_0003
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 524288, // 0x00_00_0008_0000
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 524289, // 0x00_00_0008_0001
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 524290, // 0x00_00_0008_0002
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 524291, // 0x00_00_0008_0003
- MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 1048576, // 0x00_00_0010_0000
+ MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x000000040003,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
+ MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 57e324412f540..a1eefdceae96f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -172,36 +172,36 @@ enum class Action : uint32_t {
/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
///
enum class LevelType : uint64_t {
- Undef = 0, // 0x00_00_0000_0000
- Dense = 65536, // 0x00_00_0001_0000
- Compressed = 131072, // 0x00_00_0002_0000
- CompressedNu = 131073, // 0x00_00_0002_0001
- CompressedNo = 131074, // 0x00_00_0002_0002
- CompressedNuNo = 131075, // 0x00_00_0002_0003
- Singleton = 262144, // 0x00_00_0004_0000
- SingletonNu = 262145, // 0x00_00_0004_0001
- SingletonNo = 262146, // 0x00_00_0004_0002
- SingletonNuNo = 262147, // 0x00_00_0004_0003
- LooseCompressed = 524288, // 0x00_00_0008_0000
- LooseCompressedNu = 524289, // 0x00_00_0008_0001
- LooseCompressedNo = 524290, // 0x00_00_0008_0002
- LooseCompressedNuNo = 524291, // 0x00_00_0008_0003
- NOutOfM = 1048576, // 0x00_00_0010_0000
+ Undef = 0x000000000000,
+ Dense = 0x000000010000,
+ Compressed = 0x000000020000,
+ CompressedNu = 0x000000020001,
+ CompressedNo = 0x000000020002,
+ CompressedNuNo = 0x000000020003,
+ Singleton = 0x000000040000,
+ SingletonNu = 0x000000040001,
+ SingletonNo = 0x000000040002,
+ SingletonNuNo = 0x000000040003,
+ LooseCompressed = 0x000000080000,
+ LooseCompressedNu = 0x000000080001,
+ LooseCompressedNo = 0x000000080002,
+ LooseCompressedNuNo = 0x000000080003,
+ NOutOfM = 0x000000100000,
};
/// This enum defines all supported storage format without the level properties.
enum class LevelFormat : uint64_t {
- Dense = 65536, // 0x0001_0000
- Compressed = 131072, // 0x0002_0000
- Singleton = 262144, // 0x0004_0000
- LooseCompressed = 524288, // 0x0008_0000
- NOutOfM = 1048576, // 0x0010_0000
+ Dense = 0x00010000,
+ Compressed = 0x00020000,
+ Singleton = 0x00040000,
+ LooseCompressed = 0x00080000,
+ NOutOfM = 0x00100000,
};
/// This enum defines all the nondefault properties for storage formats.
enum class LevelPropertyNondefault : uint64_t {
- Nonunique = 1, // 0x0001
- Nonordered = 2, // 0x0002
+ Nonunique = 0x0001,
+ Nonordered = 0x0002,
};
/// Get N of NOutOfM level type.
@@ -265,7 +265,7 @@ constexpr const char *toMLIRString(LevelType lt) {
return "loose_compressed(nonunique, nonordered)";
default:
if (isNOutOfMLT(lt)) {
- return "block";
+ return "structured";
}
}
return "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 08ba96d437045..5b3b971f9a7f2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
- **compressed** : only nonzeros along this level are stored
- **loose_compressed** : as compressed, but allows for free space between regions
- **singleton** : a variant of the compressed format, where coordinates have no siblings
- - **block[n, m]** : the compression uses a n:m encoding per 1xm block
+ - **structured[n, m]** : the compression uses an n:m encoding
+ (viz. n out of m consecutive elements are nonzero)
For a compressed level, each position interval is represented in a compact
way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 022cc37615f13..752d6e6481dfe 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -35,12 +35,12 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
"expected valid level format (e.g. dense, compressed or singleton)")
uint64_t properties = 0;
- SmallVector<unsigned> blockSizes;
+ SmallVector<unsigned> structure;
- if (base.compare("block") == 0) {
+ if (base.compare("structured") == 0) {
ParseResult res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::OptionalSquare,
- [&]() -> ParseResult { return parseBlockSizes(parser, &blockSizes); },
+ [&]() -> ParseResult { return parseStructure(parser, &structure); },
" in block n out of m");
FAILURE_IF_FAILED(res)
}
@@ -56,13 +56,13 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
properties |= static_cast<uint64_t>(LevelFormat::Dense);
} else if (base.compare("compressed") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::Compressed);
- } else if (base.compare("block") == 0) {
- if (blockSizes.size() != 2) {
- parser.emitError(loc, "expected exactly 2 block sizes");
+ } else if (base.compare("structured") == 0) {
+ if (structure.size() != 2) {
+ parser.emitError(loc, "expected exactly 2 structure sizes");
return failure();
}
properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
- properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]);
+ properties |= nToBits(structure[0]) | mToBits(structure[1]);
} else if (base.compare("loose_compressed") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
} else if (base.compare("singleton") == 0) {
@@ -95,8 +95,8 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
}
ParseResult
-LvlTypeParser::parseBlockSizes(AsmParser &parser,
- SmallVector<unsigned> *blockSizes) const {
+LvlTypeParser::parseStructure(AsmParser &parser,
+ SmallVector<unsigned> *structure) const {
int intVal;
auto loc = parser.getCurrentLocation();
OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
@@ -105,7 +105,7 @@ LvlTypeParser::parseBlockSizes(AsmParser &parser,
parser.emitError(loc, "failed to parse block size");
return failure();
}
- blockSizes->push_back(intVal);
+ structure->push_back(intVal);
return success();
}
parser.emitError(loc, "expected valid integer for block size");
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 250c98fb8702b..6a13112195d44 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -22,8 +22,8 @@ class LvlTypeParser {
private:
ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
- ParseResult parseBlockSizes(AsmParser &parser,
- SmallVector<unsigned> *blockSizes) const;
+ ParseResult parseStructure(AsmParser &parser,
+ SmallVector<unsigned> *structure) const;
};
} // namespace ir_detail
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index e4884a9bf393f..8293169049ca6 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -4,7 +4,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block[2, 4]
+ j mod 4 : structured[2, 4]
)
}>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 367b9c736095a..64520638b253d 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -207,12 +207,12 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block[2, 4]
+ j mod 4 : structured[2, 4]
),
crdWidth = 8 // we would even like just 2-bits
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block[2, 4]), crdWidth = 8 }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), crdWidth = 8 }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
@@ -226,11 +226,11 @@ func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
( i : dense,
j : dense,
k floordiv 4 : dense,
- k mod 4 : block[2, 4]
+ k mod 4 : structured[2, 4]
)
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block[2, 4]) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : structured[2, 4]) }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
@@ -244,11 +244,11 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
( i : dense,
k floordiv 4 : dense,
j : dense,
- k mod 4 : block[2, 4]
+ k mod 4 : structured[2, 4]
)
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block[2, 4]) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : structured[2, 4]) }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
@@ -262,11 +262,11 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
( i : dense,
k floordiv 8 : dense,
j : dense,
- k mod 8 : block[5, 8]
+ k mod 8 : structured[5, 8]
)
}>
-// CHECK-DAG: #[[$NOutOfM:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 8 : dense, d1 : dense, d2 mod 8 : block[5, 8]) }>
+// CHECK-DAG: #[[$NOutOfM:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 8 : dense, d1 : dense, d2 mod 8 : structured[5, 8]) }>
// CHECK-LABEL: func private @NOutOfM(
// CHECK-SAME: tensor<?x?x?xf64, #[[$NOutOfM]]>
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 554d6207aef7e..e47ac46597b77 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -59,7 +59,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block[2, 4]
+ j mod 4 : structured[2, 4]
),
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
index 9935d7c69e63a..ec5c7580657cd 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
@@ -41,7 +41,7 @@
#NV_24 = #sparse_tensor.encoding<{
map = ( i, j ) -> ( i : dense,
j floordiv 4 : dense,
- j mod 4 : block[2, 4]),
+ j mod 4 : structured[2, 4]),
crdWidth = 8
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index 25454f5c06b45..b0f63f12c2d57 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -20,7 +20,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block[2, 4]
+ j mod 4 : structured[2, 4]
)
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index da735b4a3b58a..311cb607b4293 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -20,7 +20,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block[2, 4]
+ j mod 4 : structured[2, 4]
)
}>
>From f6891fedf1c41033541671608b56753d81164e4d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 05:49:08 +0000
Subject: [PATCH 06/13] ensure 64 bits
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 2c71b0008ad16..72e6c7ff1f950 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -12,6 +12,7 @@
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
+#include <assert.h>
#ifdef __cplusplus
extern "C" {
@@ -44,6 +45,11 @@ enum MlirBaseSparseTensorLevelType {
MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
};
+static_assert((sizeof(enum MlirSparseTensorLevelType) == 8),
+ "MlirSparseTensorLevelType must be 8 bytes");
+
+#pragma GCC diagnostic pop
+
//===----------------------------------------------------------------------===//
// SparseTensorEncodingAttr
//===----------------------------------------------------------------------===//
>From 97d4351f740538ffeab3d0becba11a35262504d8 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 06:24:56 +0000
Subject: [PATCH 07/13] windows
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 72e6c7ff1f950..d98b3161de780 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -43,12 +43,15 @@ enum MlirBaseSparseTensorLevelType {
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
+ MLIR_SPARSE_TENSOR_LEVEL_2_OUT_OF_4 = 0x020400100000,
};
static_assert((sizeof(enum MlirSparseTensorLevelType) == 8),
"MlirSparseTensorLevelType must be 8 bytes");
+#if !defined(_MSC_VER)
#pragma GCC diagnostic pop
+#endif
//===----------------------------------------------------------------------===//
// SparseTensorEncodingAttr
>From cd93c6e5b874f25a1bfd0d6b9916a57792045f99 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 07:12:14 +0000
Subject: [PATCH 08/13] windows set enum width
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 8 --------
1 file changed, 8 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index d98b3161de780..b6a94064920ad 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -43,16 +43,8 @@ enum MlirBaseSparseTensorLevelType {
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
- MLIR_SPARSE_TENSOR_LEVEL_2_OUT_OF_4 = 0x020400100000,
};
-static_assert((sizeof(enum MlirSparseTensorLevelType) == 8),
- "MlirSparseTensorLevelType must be 8 bytes");
-
-#if !defined(_MSC_VER)
-#pragma GCC diagnostic pop
-#endif
-
//===----------------------------------------------------------------------===//
// SparseTensorEncodingAttr
//===----------------------------------------------------------------------===//
>From f7033350bd485e075dd18aee1cb23f2e6aa5a4fa Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 07:58:51 +0000
Subject: [PATCH 09/13] use typedef
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 1 -
mlir/lib/CAPI/Dialect/SparseTensor.cpp | 16 ++++++++--------
2 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index b6a94064920ad..2c71b0008ad16 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -12,7 +12,6 @@
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
-#include <assert.h>
#ifdef __cplusplus
extern "C" {
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index a34b9a29b0e90..24e41c037a8dd 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -55,11 +55,11 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
return isa<SparseTensorEncodingAttr>(unwrap(attr));
}
-MlirAttribute
-mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
- MlirSparseTensorLevelType const *lvlTypes,
- MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
- int posWidth, int crdWidth) {
+MlirAttribute mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
+ level_type const *lvlTypes,
+ MlirAffineMap dimToLvl,
+ MlirAffineMap lvlToDim,
+ int posWidth, int crdWidth) {
SmallVector<LevelType> cppLvlTypes;
cppLvlTypes.reserve(lvlRank);
for (intptr_t l = 0; l < lvlRank; ++l)
@@ -81,9 +81,9 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
}
-MlirSparseTensorLevelType
-mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
- return static_cast<MlirSparseTensorLevelType>(
+level_type mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr,
+ intptr_t lvl) {
+ return static_cast<level_type>(
cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
}
>From 1125ce7c6689fb9af0deb6dea574d1b36f9ecc90 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 09:05:22 +0000
Subject: [PATCH 10/13] level_type for pybind
---
mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 607534c615643..90fe12d1165c3 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -46,7 +46,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
mlirAttributeIsASparseTensorEncodingAttr)
.def_classmethod(
"get",
- [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
+ [](py::object cls, std::vector<level_type> lvlTypes,
std::optional<MlirAffineMap> dimToLvl,
std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
MlirContext context) {
@@ -64,7 +64,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
"lvl_types",
[](MlirAttribute self) {
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- std::vector<MlirSparseTensorLevelType> ret;
+ std::vector<level_type> ret;
ret.reserve(lvlRank);
for (int l = 0; l < lvlRank; ++l)
ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
>From 7273615dfceb4637a399e3e78b83a10d32acf63d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Fri, 2 Feb 2024 21:49:56 +0000
Subject: [PATCH 11/13] address comments
---
mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 14809381517e7..14182172f4f62 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -722,7 +722,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
coordinates[l].reserve(sz);
sz = 1;
} else if (isNOutOfMLvl(l)) {
- assert(l == lvlRank - 1 && "unexpected 2:4 usage");
+ assert(l == lvlRank - 1 && "unexpected n:m usage");
sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
coordinates[l].reserve(sz);
values.reserve(sz);
>From cff40fa127150550e2541b22bfe7efc2bc0b7ecd Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Mon, 5 Feb 2024 23:13:30 +0000
Subject: [PATCH 12/13] n:m without enum64 changes
---
mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 4 ++--
mlir/lib/CAPI/Dialect/SparseTensor.cpp | 16 ++++++++--------
2 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 90fe12d1165c3..607534c615643 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -46,7 +46,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
mlirAttributeIsASparseTensorEncodingAttr)
.def_classmethod(
"get",
- [](py::object cls, std::vector<level_type> lvlTypes,
+ [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
std::optional<MlirAffineMap> dimToLvl,
std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
MlirContext context) {
@@ -64,7 +64,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
"lvl_types",
[](MlirAttribute self) {
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- std::vector<level_type> ret;
+ std::vector<MlirSparseTensorLevelType> ret;
ret.reserve(lvlRank);
for (int l = 0; l < lvlRank; ++l)
ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 24e41c037a8dd..a34b9a29b0e90 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -55,11 +55,11 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
return isa<SparseTensorEncodingAttr>(unwrap(attr));
}
-MlirAttribute mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
- level_type const *lvlTypes,
- MlirAffineMap dimToLvl,
- MlirAffineMap lvlToDim,
- int posWidth, int crdWidth) {
+MlirAttribute
+mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
+ MlirSparseTensorLevelType const *lvlTypes,
+ MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
+ int posWidth, int crdWidth) {
SmallVector<LevelType> cppLvlTypes;
cppLvlTypes.reserve(lvlRank);
for (intptr_t l = 0; l < lvlRank; ++l)
@@ -81,9 +81,9 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
}
-level_type mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr,
- intptr_t lvl) {
- return static_cast<level_type>(
+MlirSparseTensorLevelType
+mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
+ return static_cast<MlirSparseTensorLevelType>(
cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
}
>From e32d3b6f42570db85cf9e6cbff5b48d84501ccd0 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 6 Feb 2024 18:16:01 +0000
Subject: [PATCH 13/13] toMLIRString
---
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index a1eefdceae96f..d17dbb3e57f04 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -233,7 +233,8 @@ constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
}
/// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lt) {
+constexpr const char *toMLIRString(LevelType lvlType) {
+ auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
switch (lt) {
case LevelType::Undef:
return "undef";
@@ -263,10 +264,8 @@ constexpr const char *toMLIRString(LevelType lt) {
return "loose_compressed(nonordered)";
case LevelType::LooseCompressedNuNo:
return "loose_compressed(nonunique, nonordered)";
- default:
- if (isNOutOfMLT(lt)) {
- return "structured";
- }
+ case LevelType::NOutOfM:
+ return "structured";
}
return "";
}
More information about the cfe-commits
mailing list