[Mlir-commits] [mlir] [mlir][sparse] Change DimLevelType Enum (PR #67192)
Yinying Li
llvmlistbot at llvm.org
Fri Sep 22 13:42:43 PDT 2023
https://github.com/yinying-lisa-li created https://github.com/llvm/llvm-project/pull/67192
Update DimLevelType to a 16-bit encoding that uses the lower 8 bits for storage formats and the next 4 bits (bits 8-11) for level properties. Treat CompressedWithHi and TwoOutOfFour as the Compressed format plus a property (High and Block2_4, respectively) instead of as separate formats.
Example: Compressed is 0b00000000_00000010 and CompressedWithHi is 0b00000100_00000010, i.e. CompressedWithHi has the Compressed format (low byte) and the High property (bit 10).
>From ebf4376df1481266c589a4c3967a94d738ba0fc3 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Fri, 22 Sep 2023 19:43:38 +0000
Subject: [PATCH 1/2] [mlir][sparse] Change DimLevelType Enum
Update DimLevelType to a 16-bit encoding that uses the lower 8 bits for storage formats and the next 4 bits (bits 8-11) for level properties. Treat CompressedWithHi and TwoOutOfFour as the Compressed format plus a property (High and Block2_4, respectively) instead of as separate formats.
Example: Compressed is 0b00000000_00000010 and CompressedWithHi is 0b00000100_00000010, i.e. CompressedWithHi has the Compressed format (low byte) and the High property (bit 10).
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 32 +-
.../mlir/Dialect/SparseTensor/IR/Enums.h | 275 ++++++++++++------
.../ExecutionEngine/SparseTensor/Storage.h | 2 +-
mlir/lib/CAPI/Dialect/SparseTensor.cpp | 36 +--
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 24 +-
.../SparseTensor/IR/Detail/LvlTypeParser.h | 4 +-
.../SparseTensor/IR/SparseTensorDialect.cpp | 17 +-
.../SparseTensor/Transforms/CodegenUtils.h | 2 +-
mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp | 2 +-
.../ExecutionEngine/SparseTensor/Storage.cpp | 2 +-
mlir/test/CAPI/sparse_tensor.c | 6 +-
.../test/Dialect/SparseTensor/conversion.mlir | 16 +-
.../SparseTensor/convert_dense2sparse.mlir | 16 +-
.../SparseTensor/convert_sparse2dense.mlir | 68 ++---
.../SparseTensor/convert_sparse2sparse.mlir | 20 +-
.../Dialect/SparseTensor/sparse_concat.mlir | 86 +++---
.../SparseTensor/sparse_fill_zero.mlir | 12 +-
.../python/dialects/sparse_tensor/dialect.py | 4 +-
18 files changed, 365 insertions(+), 259 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index b2e4b96c65019c5..c0676b44d34a0e6 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,20 +26,24 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
/// If updating, keep them in sync and update the static_assert in the impl
/// file.
enum MlirSparseTensorDimLevelType {
- MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b00001_00
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b00010_00
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b00010_10
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11
- MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b00100_00
- MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b00100_01
- MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b00100_10
- MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI = 32, // 0b01000_00
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU = 33, // 0b01000_01
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO = 34, // 0b01000_10
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO = 35, // 0b01000_11
- MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00
+  // Encoding: bits 0-7 hold the level format (1 = dense, 2 = compressed,
+  // 4 = singleton) and bits 8-11 hold property flags (256 = nonunique,
+  // 512 = nonordered, 1024 = high, 2048 = 2-out-of-4 blocking).
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 1, // 0b00000000_00000001
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 2, // 0b00000000_00000010
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 258, // 0b00000001_00000010
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 514, // 0b00000010_00000010
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 770, // 0b00000011_00000010
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 4, // 0b00000000_00000100
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 260, // 0b00000001_00000100
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 516, // 0b00000010_00000100
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 772, // 0b00000011_00000100
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI = 1026,
+ // 0b00000100_00000010
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU = 1282,
+ // 0b00000101_00000010
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO = 1538,
+ // 0b00000110_00000010
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO = 1794,
+ // 0b00000111_00000010
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR = 2050, // 0b00001000_00000010
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ea0d9e2d43b74c7..3e24b0be3069d06 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -156,9 +156,10 @@ enum class Action : uint32_t {
/// This enum defines all the sparse representations supportable by
/// the SparseTensor dialect. We use a lightweight encoding to encode
/// both the "format" per se (dense, compressed, singleton) as well as
-/// the "properties" (ordered, unique). The encoding is chosen for
-/// performance of the runtime library, and thus may change in future
-/// versions; consequently, client code should use the predicate functions
+/// the "properties" (ordered, unique, high, 2outof4). The encoding is chosen
+/// for performance of the runtime library, and thus may change in future
+/// versions; consequently, client code should use the predicate functions
/// defined below, rather than relying on knowledge about the particular
/// binary encoding.
///
@@ -169,42 +170,42 @@ enum class Action : uint32_t {
///
// TODO: We should generalize TwoOutOfFour to N out of M and use property to
// encode the value of N and M.
-// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
-// higher 4 bits to store level properties. Consider CompressedWithHi and
-// TwoOutOfFour as properties instead of formats.
-enum class DimLevelType : uint8_t {
- Undef = 0, // 0b00000_00
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- CompressedNu = 9, // 0b00010_01
- CompressedNo = 10, // 0b00010_10
- CompressedNuNo = 11, // 0b00010_11
- Singleton = 16, // 0b00100_00
- SingletonNu = 17, // 0b00100_01
- SingletonNo = 18, // 0b00100_10
- SingletonNuNo = 19, // 0b00100_11
- CompressedWithHi = 32, // 0b01000_00
- CompressedWithHiNu = 33, // 0b01000_01
- CompressedWithHiNo = 34, // 0b01000_10
- CompressedWithHiNuNo = 35, // 0b01000_11
- TwoOutOfFour = 64, // 0b10000_00
+// Encoding: the low byte (bits 0-7) holds the `LevelFormat` and bits 8-11
+// hold `LevelNondefaultProperty` flags; e.g. CompressedWithHiNu is
+// Compressed (2) | Nonunique (256) | High (1024) = 1282.
+enum class DimLevelType : uint16_t {
+ Undef = 0, // 0b00000000_00000000
+ Dense = 1, // 0b00000000_00000001
+ Compressed = 2, // 0b00000000_00000010
+ CompressedNu = 258, // 0b00000001_00000010
+ CompressedNo = 514, // 0b00000010_00000010
+ CompressedNuNo = 770, // 0b00000011_00000010
+ Singleton = 4, // 0b00000000_00000100
+ SingletonNu = 260, // 0b00000001_00000100
+ SingletonNo = 516, // 0b00000010_00000100
+ SingletonNuNo = 772, // 0b00000011_00000100
+ CompressedWithHi = 1026, // 0b00000100_00000010
+ CompressedWithHiNu = 1282, // 0b00000101_00000010
+ CompressedWithHiNo = 1538, // 0b00000110_00000010
+ CompressedWithHiNuNo = 1794, // 0b00000111_00000010
+ TwoOutOfFour = 2050, // 0b00001000_00000010
};
-/// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- Singleton = 16, // 0b00100_00
- CompressedWithHi = 32, // 0b01000_00
- TwoOutOfFour = 64, // 0b10000_00
+/// This enum defines all the storage formats supported by the sparse compiler,
+/// without the level properties.
+enum class LevelFormat : uint16_t {
+ Dense = 1, // 0b00000000_00000001
+ Compressed = 2, // 0b00000000_00000010
+ Singleton = 4, // 0b00000000_00000100
+ // TODO: Remove CompressedWithHi and TwoOutOfFour from LevelFormat
+ // once internal change lands.
+  // Unlike the formats above, these two values also carry a property bit:
+  // Compressed | High and Compressed | Block2_4, respectively.
+ CompressedWithHi = 1026, // 0b00000100_00000010
+ TwoOutOfFour = 2050, // 0b00001000_00000010
};
/// This enum defines all the nondefault properties for storage formats.
-enum class LevelNondefaultProperty : uint8_t {
- Nonunique = 1, // 0b00000_01
- Nonordered = 2, // 0b00000_10
- High = 32, // 0b01000_00
- Block2_4 = 64 // 0b10000_00
+/// The flags occupy the high byte so that they can be OR-ed onto a
+/// `LevelFormat` value to form a `DimLevelType` (see `buildLevelType`).
+enum class LevelNondefaultProperty : uint16_t {
+ Nonunique = 256, // 0b00000001_00000000
+ Nonordered = 512, // 0b00000010_00000000
+ High = 1024, // 0b00000100_00000000
+ Block2_4 = 2048 // 0b00001000_00000000
};
/// Returns string representation of the given dimension level type.
@@ -246,13 +247,12 @@ constexpr const char *toMLIRString(DimLevelType dlt) {
/// Check that the `DimLevelType` contains a valid (possibly undefined) value.
constexpr bool isValidDLT(DimLevelType dlt) {
- const uint8_t formatBits = static_cast<uint8_t>(dlt) >> 2;
- const uint8_t propertyBits = static_cast<uint8_t>(dlt) & 3;
+  // The low byte holds the `LevelFormat`; the high byte holds the
+  // `LevelNondefaultProperty` flags, shifted down by 8 below.
+  const uint16_t formatBits = static_cast<uint16_t>(dlt) & 255;
+  const uint16_t propertyBits = static_cast<uint16_t>(dlt) >> 8;
+  constexpr uint16_t kNuNoBits =
+      (static_cast<uint16_t>(LevelNondefaultProperty::Nonunique) |
+       static_cast<uint16_t>(LevelNondefaultProperty::Nonordered)) >>
+      8;
+  constexpr uint16_t kHighBit =
+      static_cast<uint16_t>(LevelNondefaultProperty::High) >> 8;
+  constexpr uint16_t kBlock2_4Bit =
+      static_cast<uint16_t>(LevelNondefaultProperty::Block2_4) >> 8;
// If undefined or dense, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
- return (formatBits <= 1 || formatBits == 16)
- ? (propertyBits == 0)
- : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+  if (formatBits <= 1)
+    return propertyBits == 0;
+  // Singleton only admits the nonunique/nonordered properties.
+  if (formatBits == 4)
+    return (propertyBits & ~kNuNoBits) == 0;
+  if (formatBits == 2) {
+    // Block2_4 (TwoOutOfFour) cannot be combined with any other property.
+    if (propertyBits & kBlock2_4Bit)
+      return propertyBits == kBlock2_4Bit;
+    // Otherwise any mix of nonunique/nonordered/high is allowed, but no
+    // unknown property bits.
+    return (propertyBits & ~(kNuNoBits | kHighBit)) == 0;
+  }
+  return false;
}
+
+// Property combinations that have no `DimLevelType` enumerator must be
+// rejected, while the property-carrying enumerators must be accepted.
+static_assert(
+    isValidDLT(DimLevelType::TwoOutOfFour) &&
+        isValidDLT(DimLevelType::CompressedWithHiNuNo) &&
+        isValidDLT(DimLevelType::SingletonNuNo) &&
+        !isValidDLT(static_cast<DimLevelType>(
+            static_cast<uint16_t>(DimLevelType::TwoOutOfFour) |
+            static_cast<uint16_t>(LevelNondefaultProperty::Nonunique))) &&
+        !isValidDLT(static_cast<DimLevelType>(
+            static_cast<uint16_t>(DimLevelType::CompressedWithHi) |
+            static_cast<uint16_t>(LevelNondefaultProperty::Block2_4))) &&
+        !isValidDLT(static_cast<DimLevelType>(
+            static_cast<uint16_t>(DimLevelType::Singleton) |
+            static_cast<uint16_t>(LevelNondefaultProperty::High))),
+    "isValidDLT accepts property combinations without an enumerator");
/// Check if the `DimLevelType` is the special undefined value.
@@ -270,36 +270,49 @@ constexpr bool isTwoOutOfFourDLT(DimLevelType dlt) {
return dlt == DimLevelType::TwoOutOfFour;
}
-// We use the idiom `(dlt & ~3) == format` in order to only return true
+// We use the idiom `(dlt & 255) == format` in order to only return true
// for valid DLTs. Whereas the `dlt & format` idiom is a bit faster but
// can return false-positives on invalid DLTs.
/// Check if the `DimLevelType` is compressed (regardless of properties).
+/// Since high and 2outof4 are now property bits, this is also true for the
+/// CompressedWithHi* variants and for TwoOutOfFour.
constexpr bool isCompressedDLT(DimLevelType dlt) {
- return (static_cast<uint8_t>(dlt) & ~3) ==
- static_cast<uint8_t>(DimLevelType::Compressed);
-}
-
-/// Check if the `DimLevelType` is compressed (regardless of properties).
-constexpr bool isCompressedWithHiDLT(DimLevelType dlt) {
- return (static_cast<uint8_t>(dlt) & ~3) ==
- static_cast<uint8_t>(DimLevelType::CompressedWithHi);
+ return (static_cast<uint16_t>(dlt) & 255) ==
+ static_cast<uint16_t>(DimLevelType::Compressed);
}
/// Check if the `DimLevelType` is singleton (regardless of properties).
+/// Only the low (format) byte is compared, so property bits are ignored.
constexpr bool isSingletonDLT(DimLevelType dlt) {
- return (static_cast<uint8_t>(dlt) & ~3) ==
- static_cast<uint8_t>(DimLevelType::Singleton);
+ return (static_cast<uint16_t>(dlt) & 255) ==
+ static_cast<uint16_t>(DimLevelType::Singleton);
}
/// Check if the `DimLevelType` is ordered (regardless of storage format).
+/// Ordered is the default; only the Nonordered property bit clears it.
constexpr bool isOrderedDLT(DimLevelType dlt) {
- return !(static_cast<uint8_t>(dlt) & 2);
+ return !(static_cast<uint16_t>(dlt) &
+ static_cast<uint16_t>(LevelNondefaultProperty::Nonordered));
}
/// Check if the `DimLevelType` is unique (regardless of storage format).
+/// Unique is the default; only the Nonunique property bit clears it.
constexpr bool isUniqueDLT(DimLevelType dlt) {
- return !(static_cast<uint8_t>(dlt) & 1);
+ return !(static_cast<uint16_t>(dlt) &
+ static_cast<uint16_t>(LevelNondefaultProperty::Nonunique));
+}
+
+/// Returns true iff the `High` property flag is set in `dlt`, whatever its
+/// storage format.
+constexpr bool isHighDLT(DimLevelType dlt) {
+  constexpr auto kHigh = static_cast<uint16_t>(LevelNondefaultProperty::High);
+  return (static_cast<uint16_t>(dlt) & kHigh) != 0;
+}
+
+/// Returns true iff the `Block2_4` (2-out-of-4) property flag is set in
+/// `dlt`, whatever its storage format.
+constexpr bool isBlockTwoOutOfFourDLT(DimLevelType dlt) {
+  constexpr auto kBlock2_4 =
+      static_cast<uint16_t>(LevelNondefaultProperty::Block2_4);
+  return (static_cast<uint16_t>(dlt) & kBlock2_4) != 0;
+}
+
+/// Returns true iff `dlt` is a compressed level carrying the `High` flag.
+constexpr bool isCompressedWithHiDLT(DimLevelType dlt) {
+  return isHighDLT(dlt) && isCompressedDLT(dlt);
}
/// Convert a DimLevelType to its corresponding LevelFormat.
@@ -307,18 +320,48 @@ constexpr bool isUniqueDLT(DimLevelType dlt) {
constexpr std::optional<LevelFormat> getLevelFormat(DimLevelType dlt) {
if (dlt == DimLevelType::Undef)
return std::nullopt;
- return static_cast<LevelFormat>(static_cast<uint8_t>(dlt) & ~3);
+  // TODO: Remove these two special cases once internal change lands.
+  // `LevelFormat` still has dedicated CompressedWithHi and TwoOutOfFour
+  // enumerators, so map the property-encoded DLTs back onto them. Masking
+  // off the property byte alone would fold both into plain Compressed and
+  // drop the high / 2outof4 information.
+  if (isHighDLT(dlt))
+    return LevelFormat::CompressedWithHi;
+  if (isBlockTwoOutOfFourDLT(dlt))
+    return LevelFormat::TwoOutOfFour;
+  return static_cast<LevelFormat>(static_cast<uint16_t>(dlt) & 255);
}
-/// Convert a LevelFormat to its corresponding DimLevelType with the given
-/// properties. Returns std::nullopt when the properties are not applicable for
-/// the input level format.
-/// TODO: factor out a new LevelProperties type so we can add new properties
-/// without changing this function's signature
+/// Legacy overload taking the ordered/unique properties as booleans; it
+/// produces the same result as `buildLevelType(lf, toPropertyBits(ordered,
+/// unique))`. Returns std::nullopt when the result is not a valid DLT.
+// TODO: Remove this once internal change lands.
constexpr std::optional<DimLevelType>
buildLevelType(LevelFormat lf, bool ordered, bool unique) {
- auto dlt = static_cast<DimLevelType>(static_cast<uint8_t>(lf) |
- (ordered ? 0 : 2) | (unique ? 0 : 1));
+  // `format` starts as the level format and accumulates the property bits.
+ auto format = static_cast<uint16_t>(lf);
+ if (!ordered)
+ format |= static_cast<uint16_t>(LevelNondefaultProperty::Nonordered);
+ if (!unique)
+ format |= static_cast<uint16_t>(LevelNondefaultProperty::Nonunique);
+ auto dlt = static_cast<DimLevelType>(format);
+ return isValidDLT(dlt) ? std::optional(dlt) : std::nullopt;
+}
+
+/// Converts level-property booleans into the OR of the corresponding
+/// `LevelNondefaultProperty` bits, suitable for `buildLevelType(lf, bits)`.
+/// Default properties (ordered, unique, not high, not 2outof4) set no bits.
+constexpr uint16_t toPropertyBits(bool ordered, bool unique, bool high = false,
+ bool block2_4 = false) {
+ uint16_t propertyBits = 0;
+ if (!ordered)
+ propertyBits |= static_cast<uint16_t>(LevelNondefaultProperty::Nonordered);
+ if (!unique)
+ propertyBits |= static_cast<uint16_t>(LevelNondefaultProperty::Nonunique);
+ if (high)
+ propertyBits |= static_cast<uint16_t>(LevelNondefaultProperty::High);
+ if (block2_4)
+ propertyBits |= static_cast<uint16_t>(LevelNondefaultProperty::Block2_4);
+ return propertyBits;
+}
+
+/// Convert a LevelFormat to its corresponding DimLevelType with the given
+/// nondefault properties. Returns std::nullopt when the properties are not
+/// applicable for the input level format.
+/// `propertyBits` is an OR of `LevelNondefaultProperty` values (for example,
+/// the result of `toPropertyBits`).
+constexpr std::optional<DimLevelType> buildLevelType(LevelFormat lf,
+ uint16_t propertyBits) {
+ auto dlt =
+ static_cast<DimLevelType>(static_cast<uint16_t>(lf) | propertyBits);
return isValidDLT(dlt) ? std::optional(dlt) : std::nullopt;
}
@@ -337,32 +380,56 @@ static_assert(
"getLevelFormat conversion is broken");
static_assert(
- (buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
- DimLevelType::TwoOutOfFour &&
- *buildLevelType(LevelFormat::Compressed, true, true) ==
- DimLevelType::Compressed &&
- *buildLevelType(LevelFormat::Compressed, true, false) ==
- DimLevelType::CompressedNu &&
- *buildLevelType(LevelFormat::Compressed, false, true) ==
- DimLevelType::CompressedNo &&
- *buildLevelType(LevelFormat::Compressed, false, false) ==
- DimLevelType::CompressedNuNo &&
- *buildLevelType(LevelFormat::Singleton, true, true) ==
- DimLevelType::Singleton &&
- *buildLevelType(LevelFormat::Singleton, true, false) ==
- DimLevelType::SingletonNu &&
- *buildLevelType(LevelFormat::Singleton, false, true) ==
- DimLevelType::SingletonNo &&
- *buildLevelType(LevelFormat::Singleton, false, false) ==
- DimLevelType::SingletonNuNo),
- "buildLevelType conversion is broken");
+    // Every format and property value must own its own bits. A set of
+    // non-negative values is pairwise disjoint exactly when their sum equals
+    // their bitwise OR. (A chained `&` over all values is already zero as soon
+    // as any two of them are disjoint, so it cannot detect overlaps.)
+    ((static_cast<uint16_t>(LevelFormat::Dense) +
+      static_cast<uint16_t>(LevelFormat::Compressed) +
+      static_cast<uint16_t>(LevelFormat::Singleton) +
+      static_cast<uint16_t>(LevelNondefaultProperty::Nonunique) +
+      static_cast<uint16_t>(LevelNondefaultProperty::Nonordered) +
+      static_cast<uint16_t>(LevelNondefaultProperty::High) +
+      static_cast<uint16_t>(LevelNondefaultProperty::Block2_4)) ==
+     (static_cast<uint16_t>(LevelFormat::Dense) |
+      static_cast<uint16_t>(LevelFormat::Compressed) |
+      static_cast<uint16_t>(LevelFormat::Singleton) |
+      static_cast<uint16_t>(LevelNondefaultProperty::Nonunique) |
+      static_cast<uint16_t>(LevelNondefaultProperty::Nonordered) |
+      static_cast<uint16_t>(LevelNondefaultProperty::High) |
+      static_cast<uint16_t>(LevelNondefaultProperty::Block2_4))),
+    "unique bit assignment for each level format and property is broken");
+
+// The predicates split a `DimLevelType` with `& 255` (format) and `>> 8`
+// (properties), so formats must fit in the low byte and properties must not
+// touch it.
+static_assert(
+    ((static_cast<uint16_t>(LevelFormat::Dense) |
+      static_cast<uint16_t>(LevelFormat::Compressed) |
+      static_cast<uint16_t>(LevelFormat::Singleton)) &
+     ~255) == 0 &&
+        ((static_cast<uint16_t>(LevelNondefaultProperty::Nonunique) |
+          static_cast<uint16_t>(LevelNondefaultProperty::Nonordered) |
+          static_cast<uint16_t>(LevelNondefaultProperty::High) |
+          static_cast<uint16_t>(LevelNondefaultProperty::Block2_4)) &
+         255) == 0,
+    "level formats and properties are not in their expected bytes");
+
+// Ensure both buildLevelType overloads (property bits and legacy booleans)
+// produce the expected DimLevelType values.
+static_assert((buildLevelType(LevelFormat::Dense,
+ toPropertyBits(false, true)) == std::nullopt &&
+ buildLevelType(LevelFormat::Dense,
+ toPropertyBits(true, false)) == std::nullopt &&
+ buildLevelType(LevelFormat::Dense,
+ toPropertyBits(false, false)) == std::nullopt &&
+ *buildLevelType(LevelFormat::Dense,
+ toPropertyBits(true, true)) ==
+ DimLevelType::Dense &&
+ *buildLevelType(LevelFormat::Compressed, true, true) ==
+ DimLevelType::Compressed &&
+ *buildLevelType(LevelFormat::Compressed, true, false) ==
+ DimLevelType::CompressedNu &&
+ *buildLevelType(LevelFormat::Compressed, false, true) ==
+ DimLevelType::CompressedNo &&
+ *buildLevelType(LevelFormat::Compressed, false, false) ==
+ DimLevelType::CompressedNuNo &&
+ *buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(true, true, true)) ==
+ DimLevelType::CompressedWithHi &&
+ *buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(false, true, true)) ==
+ DimLevelType::CompressedWithHiNo &&
+ *buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(true, false, true)) ==
+ DimLevelType::CompressedWithHiNu &&
+ *buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(false, false, true)) ==
+ DimLevelType::CompressedWithHiNuNo &&
+ *buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(true, true, false, true)) ==
+ DimLevelType::TwoOutOfFour &&
+ *buildLevelType(LevelFormat::Singleton, true, true) ==
+ DimLevelType::Singleton &&
+ *buildLevelType(LevelFormat::Singleton, true, false) ==
+ DimLevelType::SingletonNu &&
+ *buildLevelType(LevelFormat::Singleton, false, true) ==
+ DimLevelType::SingletonNo &&
+ *buildLevelType(LevelFormat::Singleton, false, false) ==
+ DimLevelType::SingletonNuNo),
+ "buildLevelType conversion is broken");
// Ensure the above predicates work as intended.
static_assert((isValidDLT(DimLevelType::Undef) &&
@@ -387,6 +454,11 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
isCompressedDLT(DimLevelType::CompressedNu) &&
isCompressedDLT(DimLevelType::CompressedNo) &&
isCompressedDLT(DimLevelType::CompressedNuNo) &&
+ isCompressedDLT(DimLevelType::CompressedWithHi) &&
+ isCompressedDLT(DimLevelType::CompressedWithHiNu) &&
+ isCompressedDLT(DimLevelType::CompressedWithHiNo) &&
+ isCompressedDLT(DimLevelType::CompressedWithHiNuNo) &&
+ isCompressedDLT(DimLevelType::TwoOutOfFour) &&
!isCompressedDLT(DimLevelType::Singleton) &&
!isCompressedDLT(DimLevelType::SingletonNu) &&
!isCompressedDLT(DimLevelType::SingletonNo) &&
@@ -447,6 +519,37 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
!isUniqueDLT(DimLevelType::CompressedWithHiNuNo)),
"isUniqueDLT definition is broken");
+// Ensure the property predicates isHighDLT and isBlockTwoOutOfFourDLT only
+// report their own flag.
+static_assert((!isHighDLT(DimLevelType::Dense) &&
+ !isHighDLT(DimLevelType::TwoOutOfFour) &&
+ !isHighDLT(DimLevelType::Compressed) &&
+ !isHighDLT(DimLevelType::CompressedNu) &&
+ !isHighDLT(DimLevelType::CompressedNo) &&
+ !isHighDLT(DimLevelType::CompressedNuNo) &&
+ !isHighDLT(DimLevelType::Singleton) &&
+ !isHighDLT(DimLevelType::SingletonNu) &&
+ !isHighDLT(DimLevelType::SingletonNo) &&
+ !isHighDLT(DimLevelType::SingletonNuNo) &&
+ isHighDLT(DimLevelType::CompressedWithHi) &&
+ isHighDLT(DimLevelType::CompressedWithHiNu) &&
+ isHighDLT(DimLevelType::CompressedWithHiNo) &&
+ isHighDLT(DimLevelType::CompressedWithHiNuNo)),
+ "isHighDLT definition is broken");
+
+static_assert((!isBlockTwoOutOfFourDLT(DimLevelType::Dense) &&
+ isBlockTwoOutOfFourDLT(DimLevelType::TwoOutOfFour) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::Compressed) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::CompressedNu) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::CompressedNo) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::CompressedNuNo) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::Singleton) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::SingletonNu) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::SingletonNo) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::SingletonNuNo) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::CompressedWithHi) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::CompressedWithHiNu) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::CompressedWithHiNo) &&
+ !isBlockTwoOutOfFourDLT(DimLevelType::CompressedWithHiNuNo)),
+ "isBlockTwoOutOfFourDLT definition is broken");
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 68dcab6e64c7e45..500e34a4da53660 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -683,7 +683,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
if (isDenseDLT(dlt))
return parentSz * getLvlSizes()[l];
MLIR_SPARSETENSOR_FATAL("unsupported level type: %d\n",
- static_cast<uint8_t>(dlt));
+ static_cast<uint16_t>(dlt));
}
/// Initializes sparse tensor storage scheme from a memory-resident sparse
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e18da1027e0f33a..f617c6af5da16cf 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -21,24 +21,24 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
// Ensure the C-API enums are int-castable to C++ equivalents.
static_assert(
- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
- static_cast<int>(DimLevelType::Dense) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
- static_cast<int>(DimLevelType::Compressed) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) ==
- static_cast<int>(DimLevelType::CompressedNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) ==
- static_cast<int>(DimLevelType::CompressedNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) ==
- static_cast<int>(DimLevelType::CompressedNuNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
- static_cast<int>(DimLevelType::Singleton) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) ==
- static_cast<int>(DimLevelType::SingletonNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) ==
- static_cast<int>(DimLevelType::SingletonNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) ==
- static_cast<int>(DimLevelType::SingletonNuNo),
+    // Every enumerator of the C enum must be checked here, including the
+    // CompressedWithHi* and TwoOutOfFour variants.
+    static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
+            static_cast<uint16_t>(DimLevelType::Dense) &&
+        static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
+            static_cast<uint16_t>(DimLevelType::Compressed) &&
+        static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) ==
+            static_cast<uint16_t>(DimLevelType::CompressedNu) &&
+        static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) ==
+            static_cast<uint16_t>(DimLevelType::CompressedNo) &&
+        static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) ==
+            static_cast<uint16_t>(DimLevelType::CompressedNuNo) &&
+        static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
+            static_cast<uint16_t>(DimLevelType::Singleton) &&
+        static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) ==
+            static_cast<uint16_t>(DimLevelType::SingletonNu) &&
+        static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) ==
+            static_cast<uint16_t>(DimLevelType::SingletonNo) &&
+        static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) ==
+            static_cast<uint16_t>(DimLevelType::SingletonNuNo) &&
+        static_cast<uint16_t>(
+            MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI) ==
+            static_cast<uint16_t>(DimLevelType::CompressedWithHi) &&
+        static_cast<uint16_t>(
+            MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU) ==
+            static_cast<uint16_t>(DimLevelType::CompressedWithHiNu) &&
+        static_cast<uint16_t>(
+            MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO) ==
+            static_cast<uint16_t>(DimLevelType::CompressedWithHiNo) &&
+        static_cast<uint16_t>(
+            MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO) ==
+            static_cast<uint16_t>(DimLevelType::CompressedWithHiNuNo) &&
+        static_cast<uint16_t>(MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR) ==
+            static_cast<uint16_t>(DimLevelType::TwoOutOfFour),
"MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 44a2c7d49619405..85468023fa1b25c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -47,7 +47,7 @@ using namespace mlir::sparse_tensor::ir_detail;
// `LvlTypeParser` implementation.
//===----------------------------------------------------------------------===//
-FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+FailureOr<uint16_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
StringRef base;
const auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
@@ -62,17 +62,11 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
// Set the base bit for properties.
if (base.compare("dense") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Dense);
+ properties |= static_cast<uint16_t>(LevelFormat::Dense);
} else if (base.compare("compressed") == 0) {
- // TODO: Remove this condition once dimLvlType enum is refactored. Current
- // enum treats High and TwoOutOfFour as formats instead of properties.
- if (!(properties & static_cast<uint8_t>(LevelNondefaultProperty::High) ||
- properties &
- static_cast<uint8_t>(LevelNondefaultProperty::Block2_4))) {
- properties |= static_cast<uint8_t>(LevelFormat::Compressed);
- }
+ properties |= static_cast<uint16_t>(LevelFormat::Compressed);
} else if (base.compare("singleton") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Singleton);
+ properties |= static_cast<uint16_t>(LevelFormat::Singleton);
} else {
parser.emitError(loc, "unknown level format: ") << base;
return failure();
@@ -84,19 +78,19 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
}
ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
- uint8_t *properties) const {
+ uint16_t *properties) const {
StringRef strVal;
auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
"expected valid level property (e.g. nonordered, nonunique or high)")
if (strVal.compare("nonunique") == 0) {
- *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
+ *properties |= static_cast<uint16_t>(LevelNondefaultProperty::Nonunique);
} else if (strVal.compare("nonordered") == 0) {
- *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonordered);
+ *properties |= static_cast<uint16_t>(LevelNondefaultProperty::Nonordered);
} else if (strVal.compare("high") == 0) {
- *properties |= static_cast<uint8_t>(LevelNondefaultProperty::High);
+ *properties |= static_cast<uint16_t>(LevelNondefaultProperty::High);
} else if (strVal.compare("block2_4") == 0) {
- *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Block2_4);
+ *properties |= static_cast<uint16_t>(LevelNondefaultProperty::Block2_4);
} else {
parser.emitError(loc, "unknown level property: ") << strVal;
return failure();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 10fb6c8f1c04730..3fd0252e90c761c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -18,10 +18,10 @@ namespace ir_detail {
class LvlTypeParser {
public:
LvlTypeParser() = default;
+  /// Parses a level type: a base format keyword (dense, compressed or
+  /// singleton) plus any level properties. The implementation builds the
+  /// value by OR-ing `LevelFormat` and `LevelNondefaultProperty` bits, which
+  /// now require 16 bits.
- FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
+ FailureOr<uint16_t> parseLvlType(AsmParser &parser) const;
private:
+  /// Parses one property keyword (nonunique, nonordered, high or block2_4)
+  /// and ORs the matching `LevelNondefaultProperty` bit into `*properties`.
- ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
+ ParseResult parseProperty(AsmParser &parser, uint16_t *properties) const;
};
} // namespace ir_detail
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 9675a61109477b5..ffe2d917b1ef6c1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -747,15 +747,19 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
// An unordered and non-unique compressed level at beginning.
// If this is also the last level, then it is unique.
- lvlTypes.push_back(
- *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
+ lvlTypes.push_back(*buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(ordered, lvlRank == 1)));
if (lvlRank > 1) {
// TODO: it is actually ordered at the level for ordered input.
// Followed by unordered non-unique n-2 singleton levels.
- std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
- *buildLevelType(LevelFormat::Singleton, ordered, false));
+ std::fill_n(
+ std::back_inserter(lvlTypes), lvlRank - 2,
+ *buildLevelType(LevelFormat::Singleton,
+ toPropertyBits(ordered, false)));
// Ends by a unique singleton level unless the lvlRank is 1.
- lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
+ lvlTypes.push_back(
+ *buildLevelType(LevelFormat::Singleton,
+ toPropertyBits(ordered, true)));
}
// TODO: Maybe pick the bitwidth based on input/output tensors (probably the
@@ -831,7 +835,8 @@ static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
SmallVector<DimLevelType> dlts;
for (auto dlt : enc.getLvlTypes())
- dlts.push_back(*buildLevelType(*getLevelFormat(dlt), true, true));
+ dlts.push_back(*buildLevelType(*getLevelFormat(dlt),
+ toPropertyBits(true, true)));
return SparseTensorEncodingAttr::get(
enc.getContext(), dlts,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index a6468b3e14795f7..050fbee96f77a43 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -412,7 +412,7 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
/// Generates a constant of the internal dimension level type encoding.
+/// The constant is an i16 because `DimLevelType` is now a 16-bit enum.
inline Value constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
DimLevelType dlt) {
- return constantI8(builder, loc, static_cast<uint8_t>(dlt));
+ return constantI16(builder, loc, static_cast<uint16_t>(dlt));
}
inline bool isZeroRankedTensorOrScalar(Type type) {
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp b/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp
index c6fd669ad513f46..1ff491483c0e744 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/NNZ.cpp
@@ -47,7 +47,7 @@ SparseTensorNNZ::SparseTensorNNZ(const std::vector<uint64_t> &lvlSizes,
// for dense-after-compressed.
} else {
MLIR_SPARSETENSOR_FATAL("unsupported level type: %d\n",
- static_cast<uint8_t>(dlt));
+ static_cast<uint16_t>(dlt));
}
sz = detail::checkedMul(sz, lvlSizes[l]);
}
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 2c4f0123ed4417c..4dd42695f3a25a8 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -55,7 +55,7 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
// when `NDEBUG` is true), thereby reducing the need to re-assert things.
if (!(isDenseDLT(dlt) || isCompressedDLT(dlt) || isSingletonDLT(dlt)))
MLIR_SPARSETENSOR_FATAL("unsupported level type: %d\n",
- static_cast<uint8_t>(dlt));
+ static_cast<uint16_t>(dlt));
}
}
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 6449a8f0c79403c..6746e4777674811 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -38,9 +38,9 @@ static int testRoundtripEncoding(MlirContext ctx) {
mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr);
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(dimToLvl);
- // CHECK: level_type: 4
- // CHECK: level_type: 8
- // CHECK: level_type: 8
+ // CHECK: level_type: 1
+ // CHECK: level_type: 2
+ // CHECK: level_type: 2
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
enum MlirSparseTensorDimLevelType *lvlTypes =
malloc(sizeof(enum MlirSparseTensorDimLevelType) * lvlRank);
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 9cc5cc01544ccbe..5ef37e2f666ad41 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -80,8 +80,8 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi16>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi16> to memref<?xi16>
// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
@@ -98,8 +98,8 @@ func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector>
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi16> to memref<?xi16>
// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
@@ -120,8 +120,8 @@ func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
// CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi16>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi16> to memref<?xi16>
// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Lvl2Dim]], %[[Dim2Lvl]], %{{.*}}, %{{.*}}, %{{.*}})
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
@@ -138,11 +138,11 @@ func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[DimSizes0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi16>
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizes:.*]] = memref.cast %[[DimSizes0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi16> to memref<?xi16>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I]], %[[DimSizes0]][%[[C0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[J]], %[[DimSizes0]][%[[C1]]] : memref<2xindex>
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 9fb1946d56263db..8a74a0ab3634570 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -26,11 +26,11 @@
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi16> to memref<?xi16>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<1xindex> to memref<?xindex>
@@ -84,11 +84,11 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
// CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<2xindex> to memref<?xindex>
@@ -138,11 +138,11 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<2xindex> to memref<?xindex>
@@ -210,12 +210,12 @@ func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
// CHECK-DAG: %[[U1:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf64>
// CHECK-DAG: %[[U2:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf64>
// CHECK-DAG: %[[U3:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf64>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<3xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<3xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<3xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<3xindex>
// CHECK-DAG: %[[Lvl2Dim:.*]] = memref.alloca() : memref<3xindex>
// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.alloca() : memref<3xindex>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<3xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<3xi16> to memref<?xi16>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<3xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<3xindex> to memref<?xindex>
// CHECK-DAG: %[[Lvl2DimP:.*]] = memref.cast %[[Lvl2Dim]] : memref<3xindex> to memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
index 621235182c9a840..2ef50cbd7d82adc 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -19,13 +19,13 @@
// CHECK-SAME: %[[Arg:.*]]: !llvm.ptr<i8>) -> tensor<13xi32>
// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[I13:.*]] = arith.constant 13 : index
-// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 1 : i16
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : i32
//
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<1xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<1xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I13]], %[[DimSizes]][%[[I0]]] : memref<1xindex>
@@ -60,13 +60,13 @@ func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13x
// CHECK-LABEL: func @sparse_convert_1d_dyn(
// CHECK-SAME: %[[Arg:.*]]: !llvm.ptr<i8>) -> tensor<?xi32>
// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 1 : i16
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : i32
//
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<1xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<1xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
@@ -105,14 +105,14 @@ func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<
// CHECK-DAG: %[[I1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[I2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 1 : i16
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
//
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I2]], %[[DimSizes]][%[[I0]]] : memref<2xindex>
@@ -164,14 +164,14 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x
// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[I1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 1 : i16
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
//
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
@@ -212,14 +212,14 @@ func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tens
// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[I1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[I2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 1 : i16
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
//
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr<i8>, index) -> index
@@ -259,14 +259,14 @@ func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tens
// CHECK-SAME: %[[Arg:.*]]: !llvm.ptr<i8>) -> tensor<?x?xf64>
// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[I1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 1 : i16
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
//
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr<i8>, index) -> index
@@ -326,15 +326,15 @@ func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tens
// CHECK-DAG: %[[I2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[I3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 1 : i16
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[E0:.*]] = arith.constant 0.000000e+00 : f64
//
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<3xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<3xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<3xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<3xi8>
-// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I2]]] : memref<3xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<3xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<3xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I0]]] : memref<3xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I1]]] : memref<3xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes]][%[[I2]]] : memref<3xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<3xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<3xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I2]], %[[DimSizes]][%[[I0]]] : memref<3xindex>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 54cdfc690952d9a..dcff3f622df3410 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -62,11 +62,11 @@ func.func @sparse_hidden_nop_cast(%arg0: tensor<32xf32, #SparseVector>) -> tenso
// CHECK-LABEL: func @sparse_convert_1d_ss(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK-DAG: %[[SparseToSparse:.*]] = arith.constant 3 : i32
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi16> to memref<?xi16>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<1xindex> to memref<?xindex>
@@ -81,11 +81,11 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
// CHECK-COO-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK-COO-DAG: %[[ToCOO:.*]] = arith.constant 5 : i32
// CHECK-COO-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
-// CHECK-COO-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-COO-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi16>
// CHECK-COO-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-COO-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-COO-DAG: %[[Iota:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-COO-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi8> to memref<?xi8>
+// CHECK-COO-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi16> to memref<?xi16>
// CHECK-COO-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-COO-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-COO-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<1xindex> to memref<?xindex>
@@ -97,11 +97,11 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
// CHECK-AUTO-LABEL: func @sparse_convert(
// CHECK-AUTO-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK-AUTO-DAG: %[[SparseToSparse:.*]] = arith.constant 3 : i32
-// CHECK-AUTO-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-AUTO-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi16>
// CHECK-AUTO-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-AUTO-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-AUTO-DAG: %[[Iota:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-AUTO-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi8> to memref<?xi8>
+// CHECK-AUTO-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi16> to memref<?xi16>
// CHECK-AUTO-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-AUTO-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-AUTO-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<1xindex> to memref<?xindex>
@@ -129,11 +129,11 @@ func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32
// CHECK-COO-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK-COO-DAG: %[[ToCOO:.*]] = arith.constant 5 : i32
// CHECK-COO-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
-// CHECK-COO-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-COO-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi16>
// CHECK-COO-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-COO-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-COO-DAG: %[[Iota:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-COO-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi8> to memref<?xi8>
+// CHECK-COO-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi16> to memref<?xi16>
// CHECK-COO-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-COO-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-COO-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<1xindex> to memref<?xindex>
@@ -145,11 +145,11 @@ func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32
// CHECK-AUTO-LABEL: func @sparse_convert_singleton(
// CHECK-AUTO-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK-AUTO-DAG: %[[SparseToSparse:.*]] = arith.constant 3 : i32
-// CHECK-AUTO-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi8>
+// CHECK-AUTO-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<1xi16>
// CHECK-AUTO-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-AUTO-DAG: %[[LvlSizes:.*]] = memref.alloca() : memref<1xindex>
// CHECK-AUTO-DAG: %[[Iota:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-AUTO-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi8> to memref<?xi8>
+// CHECK-AUTO-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<1xi16> to memref<?xi16>
// CHECK-AUTO-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-AUTO-DAG: %[[LvlSizesP:.*]] = memref.cast %[[LvlSizes]] : memref<1xindex> to memref<?xindex>
// CHECK-AUTO-DAG: %[[IotaP:.*]] = memref.cast %[[Iota]] : memref<1xindex> to memref<?xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index 5f412e59dba9f87..61412da867e4992 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -13,7 +13,7 @@
// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
+// CHECK-DAG: %[[CompressedDLT:.*]] = arith.constant 2 : i16
// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
@@ -30,10 +30,10 @@
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]][%[[TMP_c0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]][%[[TMP_c1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes]][%[[TMP_c0]]] : memref<2xindex>
@@ -84,11 +84,11 @@ func.func @concat_mix_dense(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spar
// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[CompressedDLT:.*]] = arith.constant 2 : i16
+// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes_0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP_0:.*]] = memref.cast %[[DimSizes_0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[TMP_c5]], %[[DimSizes_0]][%[[TMP_c0]]] : memref<2xindex>
@@ -116,10 +116,10 @@ func.func @concat_mix_dense(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spar
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes_1:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP_1:.*]] = memref.cast %[[DimSizes_1]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes_1]][%[[TMP_c0]]] : memref<2xindex>
@@ -172,11 +172,11 @@ func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spa
// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[CompressedDLT:.*]] = arith.constant 2 : i16
+// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes_0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP_0:.*]] = memref.cast %[[DimSizes_0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_0]][%[[TMP_c0]]] : memref<2xindex>
@@ -206,10 +206,10 @@ func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spa
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes_1:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP_1:.*]] = memref.cast %[[DimSizes_1]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_1]][%[[TMP_c0]]] : memref<2xindex>
@@ -254,7 +254,7 @@ func.func @concat_mix_sparse_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3
// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
+// CHECK-DAG: %[[CompressedDLT:.*]] = arith.constant 2 : i16
// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
@@ -271,10 +271,10 @@ func.func @concat_mix_sparse_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]][%[[TMP_c0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]][%[[TMP_c1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes]][%[[TMP_c0]]] : memref<2xindex>
@@ -317,7 +317,7 @@ func.func @concat_mix_dense_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3x
// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
+// CHECK-DAG: %[[CompressedDLT:.*]] = arith.constant 2 : i16
// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
@@ -334,10 +334,10 @@ func.func @concat_mix_dense_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3x
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]][%[[TMP_c0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]][%[[TMP_c1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes]][%[[TMP_c0]]] : memref<2xindex>
@@ -378,8 +378,8 @@ func.func @concat_mix_dense_perm_dim1_dyn(%arg0: tensor<3x2xf64>, %arg1: tensor<
// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
-// CHECK-DAG: %[[TMP_c4_i8:.*]] = arith.constant 4 : i8
-// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
+// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 1 : i16
+// CHECK-DAG: %[[CompressedDLT:.*]] = arith.constant 2 : i16
// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
@@ -388,10 +388,10 @@ func.func @concat_mix_dense_perm_dim1_dyn(%arg0: tensor<3x2xf64>, %arg1: tensor<
// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c4_i8]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c4_i8]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[DenseDLT]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes_0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP_0:.*]] = memref.cast %[[DimSizes_0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_0]][%[[TMP_c0]]] : memref<2xindex>
@@ -417,10 +417,10 @@ func.func @concat_mix_dense_perm_dim1_dyn(%arg0: tensor<3x2xf64>, %arg1: tensor<
// CHECK: }
// CHECK: }
// CHECK: }
-// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes_1:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP_1:.*]] = memref.cast %[[DimSizes_1]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_1]][%[[TMP_c0]]] : memref<2xindex>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 7a4989304b5be2d..cc91cad7f57317e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,11 +14,11 @@
// CHECK-DAG: %[[True:.*]] = arith.constant true
// CHECK-DAG: %[[I100:.*]] = arith.constant 100 : index
// CHECK-DAG: %[[I300:.*]] = arith.constant 300 : index
-// CHECK-DAG: %[[CompressedDLT:.*]] = arith.constant 8 : i8
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi8> to memref<?xi8>
-// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]]{{\[}}%[[I0]]] : memref<2xi8>
-// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]]{{\[}}%[[I1]]] : memref<2xi8>
+// CHECK-DAG: %[[CompressedDLT:.*]] = arith.constant 2 : i16
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.alloca() : memref<2xi16>
+// CHECK-DAG: %[[LvlTypesP:.*]] = memref.cast %[[LvlTypes]] : memref<2xi16> to memref<?xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]]{{\[}}%[[I0]]] : memref<2xi16>
+// CHECK-DAG: memref.store %[[CompressedDLT]], %[[LvlTypes]]{{\[}}%[[I1]]] : memref<2xi16>
// CHECK-DAG: %[[DimSizes:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[DimSizesP:.*]] = memref.cast %[[DimSizes]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I100]], %[[DimSizes]]{{\[}}%[[I0]]] : memref<2xindex>
@@ -32,7 +32,7 @@
// CHECK-DAG: memref.store %[[I0]], %[[Iota]]{{\[}}%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[I1]], %[[Iota]]{{\[}}%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[NullPtr:.*]] = llvm.mlir.null : !llvm.ptr<i8>
-// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[DimSizesP]], %[[LvlSizesP]], %[[LvlTypesP]], %[[IotaP]], %[[IotaP]], %[[C0]], %[[C0]], %[[C1]], %[[C0]], %[[NullPtr]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[DimSizesP]], %[[LvlSizesP]], %[[LvlTypesP]], %[[IotaP]], %[[IotaP]], %[[C0]], %[[C0]], %[[C1]], %[[C0]], %[[NullPtr]]) : (memref<?xindex>, memref<?xindex>, memref<?xi16>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index e1048edce184a51..0548c31828dbb48 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")
- # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
+ # CHECK: lvl_types: [<DimLevelType.compressed: 2>]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: None
print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -68,7 +68,7 @@ def testEncodingAttr2D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")
- # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
+ # CHECK: lvl_types: [<DimLevelType.dense: 1>, <DimLevelType.compressed: 2>]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
>From 6d4116e8fbcbabbd278e2847f596e92650f41e26 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Fri, 22 Sep 2023 20:38:57 +0000
Subject: [PATCH 2/2] Fix buildLevelType static_asserts to use toPropertyBits; widen parsed level properties to uint16_t
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 24 ++++++++++++-------
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 2 +-
2 files changed, 17 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 3e24b0be3069d06..565f139cc431f00 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -398,13 +398,17 @@ static_assert((buildLevelType(LevelFormat::Dense,
*buildLevelType(LevelFormat::Dense,
toPropertyBits(true, true)) ==
DimLevelType::Dense &&
- *buildLevelType(LevelFormat::Compressed, true, true) ==
+ *buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(true, true)) ==
DimLevelType::Compressed &&
- *buildLevelType(LevelFormat::Compressed, true, false) ==
+ *buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(true, false)) ==
DimLevelType::CompressedNu &&
- *buildLevelType(LevelFormat::Compressed, false, true) ==
+ *buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(false, true)) ==
DimLevelType::CompressedNo &&
- *buildLevelType(LevelFormat::Compressed, false, false) ==
+ *buildLevelType(LevelFormat::Compressed,
+ toPropertyBits(false, false)) ==
DimLevelType::CompressedNuNo &&
*buildLevelType(LevelFormat::Compressed,
toPropertyBits(true, true, true)) ==
@@ -421,13 +425,17 @@ static_assert((buildLevelType(LevelFormat::Dense,
*buildLevelType(LevelFormat::Compressed,
toPropertyBits(true, true, false, true)) ==
DimLevelType::TwoOutOfFour &&
- *buildLevelType(LevelFormat::Singleton, true, true) ==
+ *buildLevelType(LevelFormat::Singleton,
+ toPropertyBits(true, true)) ==
DimLevelType::Singleton &&
- *buildLevelType(LevelFormat::Singleton, true, false) ==
+ *buildLevelType(LevelFormat::Singleton,
+ toPropertyBits(true, false)) ==
DimLevelType::SingletonNu &&
- *buildLevelType(LevelFormat::Singleton, false, true) ==
+ *buildLevelType(LevelFormat::Singleton,
+ toPropertyBits(false, true)) ==
DimLevelType::SingletonNo &&
- *buildLevelType(LevelFormat::Singleton, false, false) ==
+ *buildLevelType(LevelFormat::Singleton,
+ toPropertyBits(false, false)) ==
DimLevelType::SingletonNuNo),
"buildLevelType conversion is broken");
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 85468023fa1b25c..55a13c572ea094d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -52,7 +52,7 @@ FailureOr<uint16_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
const auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
"expected valid level format (e.g. dense, compressed or singleton)")
- uint8_t properties = 0;
+ uint16_t properties = 0;
ParseResult res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::OptionalParen,
More information about the Mlir-commits mailing list