[Mlir-commits] [mlir] b958954 - [mlir][sparse] introduce a new compressed(hi) dimension level type
Peiming Liu
llvmlistbot at llvm.org
Tue Apr 18 16:26:19 PDT 2023
Author: Peiming Liu
Date: 2023-04-18T23:26:11Z
New Revision: b9589545c45b8b42a7434075151a4b1d4b798a70
URL: https://github.com/llvm/llvm-project/commit/b9589545c45b8b42a7434075151a4b1d4b798a70
DIFF: https://github.com/llvm/llvm-project/commit/b9589545c45b8b42a7434075151a4b1d4b798a70.diff
LOG: [mlir][sparse] introduce a new compressed(hi) dimension level type
`compressed(hi)` is similar to `compressed`, but instead of reusing the previous segment's high position as the current segment's low position, it stores a separate (low, high) pair of positions for each sparse index.
The patch only introduces the definition (syntax), spelled `compressed-hi` with the `-nu`, `-no`, and `-nu-no` property variants, but does not provide a codegen implementation.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D148664
Added:
Modified:
mlir/include/mlir-c/Dialect/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Bindings/Python/DialectSparseTensor.cpp
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 7d560dd80a90a..8a6763b6ca89e 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,15 +26,19 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
/// If updating, keep them in sync and update the static_assert in the impl
/// file.
enum MlirSparseTensorDimLevelType {
- MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b001_00
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b010_00
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b010_01
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b010_10
- MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b010_11
- MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b100_00
- MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b100_01
- MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b100_10
- MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b100_11
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b0001_00
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b0010_00
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b0010_01
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b0010_10
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b0010_11
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b0100_00
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b0100_01
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b0100_10
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b0100_11
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI = 32, // 0b1000_00
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU = 33, // 0b1000_01
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO = 34, // 0b1000_10
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO = 35, // 0b1000_11
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 89065b15bf230..acb543adb81b8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -172,24 +172,29 @@ enum class Action : uint32_t {
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
enum class DimLevelType : uint8_t {
- Undef = 0, // 0b000_00
- Dense = 4, // 0b001_00
- Compressed = 8, // 0b010_00
- CompressedNu = 9, // 0b010_01
- CompressedNo = 10, // 0b010_10
- CompressedNuNo = 11, // 0b010_11
- Singleton = 16, // 0b100_00
- SingletonNu = 17, // 0b100_01
- SingletonNo = 18, // 0b100_10
- SingletonNuNo = 19, // 0b100_11
+ Undef = 0, // 0b0000_00
+ Dense = 4, // 0b0001_00
+ Compressed = 8, // 0b0010_00
+ CompressedNu = 9, // 0b0010_01
+ CompressedNo = 10, // 0b0010_10
+ CompressedNuNo = 11, // 0b0010_11
+ Singleton = 16, // 0b0100_00
+ SingletonNu = 17, // 0b0100_01
+ SingletonNo = 18, // 0b0100_10
+ SingletonNuNo = 19, // 0b0100_11
+ CompressedWithHi = 32, // 0b1000_00
+ CompressedWithHiNu = 33, // 0b1000_01
+ CompressedWithHiNo = 34, // 0b1000_10
+ CompressedWithHiNuNo = 35, // 0b1000_11
};
/// This enum defines all the storage formats supported by the sparse compiler,
/// without the level properties.
enum class LevelFormat : uint8_t {
- Dense = 4, // 0b001_00
- Compressed = 8, // 0b010_00
- Singleton = 16, // 0b100_00
+ Dense = 4, // 0b0001_00
+ Compressed = 8, // 0b0010_00
+ Singleton = 16, // 0b0100_00
+ CompressedWithHi = 32, // 0b1000_00
};
/// Returns string representation of the given dimension level type.
@@ -216,6 +221,14 @@ inline std::string toMLIRString(DimLevelType dlt) {
return "singleton-no";
case DimLevelType::SingletonNuNo:
return "singleton-nu-no";
+ case DimLevelType::CompressedWithHi:
+ return "compressed-hi";
+ case DimLevelType::CompressedWithHiNu:
+ return "compressed-hi-nu";
+ case DimLevelType::CompressedWithHiNo:
+ return "compressed-hi-no";
+ case DimLevelType::CompressedWithHiNuNo:
+ return "compressed-hi-nu-no";
}
return "";
}
@@ -226,8 +239,9 @@ constexpr bool isValidDLT(DimLevelType dlt) {
const uint8_t propertyBits = static_cast<uint8_t>(dlt) & 3;
// If undefined or dense, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
- return (formatBits <= 1) ? (propertyBits == 0)
- : (formatBits == 2 || formatBits == 4);
+ return (formatBits <= 1)
+ ? (propertyBits == 0)
+ : (formatBits == 2 || formatBits == 4 || formatBits == 8);
}
/// Check if the `DimLevelType` is the special undefined value.
@@ -250,6 +264,12 @@ constexpr bool isCompressedDLT(DimLevelType dlt) {
static_cast<uint8_t>(DimLevelType::Compressed);
}
+/// Check if the `DimLevelType` is compressed-with-hi (regardless of properties).
+constexpr bool isCompressedWithHiDLT(DimLevelType dlt) {
+ return (static_cast<uint8_t>(dlt) & ~3) ==
+ static_cast<uint8_t>(DimLevelType::CompressedWithHi);
+}
+
/// Check if the `DimLevelType` is singleton (regardless of properties).
constexpr bool isSingletonDLT(DimLevelType dlt) {
return (static_cast<uint8_t>(dlt) & ~3) ==
@@ -333,7 +353,11 @@ static_assert((isValidDLT(DimLevelType::Undef) &&
isValidDLT(DimLevelType::Singleton) &&
isValidDLT(DimLevelType::SingletonNu) &&
isValidDLT(DimLevelType::SingletonNo) &&
- isValidDLT(DimLevelType::SingletonNuNo)),
+ isValidDLT(DimLevelType::SingletonNuNo) &&
+ isValidDLT(DimLevelType::CompressedWithHi) &&
+ isValidDLT(DimLevelType::CompressedWithHiNu) &&
+ isValidDLT(DimLevelType::CompressedWithHiNo) &&
+ isValidDLT(DimLevelType::CompressedWithHiNuNo)),
"isValidDLT definition is broken");
static_assert((!isCompressedDLT(DimLevelType::Dense) &&
@@ -347,6 +371,17 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
!isCompressedDLT(DimLevelType::SingletonNuNo)),
"isCompressedDLT definition is broken");
+static_assert((!isCompressedWithHiDLT(DimLevelType::Dense) &&
+ isCompressedWithHiDLT(DimLevelType::CompressedWithHi) &&
+ isCompressedWithHiDLT(DimLevelType::CompressedWithHiNu) &&
+ isCompressedWithHiDLT(DimLevelType::CompressedWithHiNo) &&
+ isCompressedWithHiDLT(DimLevelType::CompressedWithHiNuNo) &&
+ !isCompressedWithHiDLT(DimLevelType::Singleton) &&
+ !isCompressedWithHiDLT(DimLevelType::SingletonNu) &&
+ !isCompressedWithHiDLT(DimLevelType::SingletonNo) &&
+ !isCompressedWithHiDLT(DimLevelType::SingletonNuNo)),
+ "isCompressedWithHiDLT definition is broken");
+
static_assert((!isSingletonDLT(DimLevelType::Dense) &&
!isSingletonDLT(DimLevelType::Compressed) &&
!isSingletonDLT(DimLevelType::CompressedNu) &&
@@ -366,7 +401,11 @@ static_assert((isOrderedDLT(DimLevelType::Dense) &&
isOrderedDLT(DimLevelType::Singleton) &&
isOrderedDLT(DimLevelType::SingletonNu) &&
!isOrderedDLT(DimLevelType::SingletonNo) &&
- !isOrderedDLT(DimLevelType::SingletonNuNo)),
+ !isOrderedDLT(DimLevelType::SingletonNuNo) &&
+ isOrderedDLT(DimLevelType::CompressedWithHi) &&
+ isOrderedDLT(DimLevelType::CompressedWithHiNu) &&
+ !isOrderedDLT(DimLevelType::CompressedWithHiNo) &&
+ !isOrderedDLT(DimLevelType::CompressedWithHiNuNo)),
"isOrderedDLT definition is broken");
static_assert((isUniqueDLT(DimLevelType::Dense) &&
@@ -377,7 +416,11 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
isUniqueDLT(DimLevelType::Singleton) &&
!isUniqueDLT(DimLevelType::SingletonNu) &&
isUniqueDLT(DimLevelType::SingletonNo) &&
- !isUniqueDLT(DimLevelType::SingletonNuNo)),
+ !isUniqueDLT(DimLevelType::SingletonNuNo) &&
+ isUniqueDLT(DimLevelType::CompressedWithHi) &&
+ !isUniqueDLT(DimLevelType::CompressedWithHiNu) &&
+ isUniqueDLT(DimLevelType::CompressedWithHiNo) &&
+ !isUniqueDLT(DimLevelType::CompressedWithHiNuNo)),
"isUniqueDLT definition is broken");
} // namespace sparse_tensor
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index a5b96a86596eb..5ba7ff8744795 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -337,6 +337,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); }
bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); }
+ bool isCompressedWithHiLvl(::mlir::sparse_tensor::Level l) const { return isCompressedWithHiDLT(getLvlType(l)); }
bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); }
bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); }
bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); }
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index e84937df89ba1..0e07f256344f9 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -26,7 +26,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
.value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON)
.value("singleton-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU)
.value("singleton-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO)
- .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO);
+ .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO)
+ .value("compressed-hi", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI)
+ .value("compressed-hi-nu",
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU)
+ .value("compressed-hi-no",
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO)
+ .value("compressed-hi-nu-no",
+ MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO);
mlir_attribute_subclass(m, "EncodingAttr",
mlirAttributeIsASparseTensorEncodingAttr)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 11e6ac81b7e14..1a93b14f780dd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -198,12 +198,19 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
return getStaticDimSliceStride(toOrigDim(*this, lvl));
}
-const static DimLevelType validDLTs[] = {
- DimLevelType::Dense, DimLevelType::Compressed,
- DimLevelType::CompressedNu, DimLevelType::CompressedNo,
- DimLevelType::CompressedNuNo, DimLevelType::Singleton,
- DimLevelType::SingletonNu, DimLevelType::SingletonNo,
- DimLevelType::SingletonNuNo};
+const static DimLevelType validDLTs[] = {DimLevelType::Dense,
+ DimLevelType::Compressed,
+ DimLevelType::CompressedNu,
+ DimLevelType::CompressedNo,
+ DimLevelType::CompressedNuNo,
+ DimLevelType::Singleton,
+ DimLevelType::SingletonNu,
+ DimLevelType::SingletonNo,
+ DimLevelType::SingletonNuNo,
+ DimLevelType::CompressedWithHi,
+ DimLevelType::CompressedWithHiNu,
+ DimLevelType::CompressedWithHiNo,
+ DimLevelType::CompressedWithHiNuNo};
static std::optional<DimLevelType> parseDLT(StringRef str) {
for (DimLevelType dlt : validDLTs)
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index d2d6fab60f124..087ce42f18779 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -55,6 +55,16 @@ func.func private @sparse_coo(tensor<?x?xf32, #COO>)
// -----
+#BCOO = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ]
+}>
+
+// CHECK-LABEL: func private @sparse_bcoo(
+// CHECK-SAME: tensor<?x?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ] }>>)
+func.func private @sparse_bcoo(tensor<?x?x?xf32, #BCOO>)
+
+// -----
+
#SortedCOO = #sparse_tensor.encoding<{
dimLevelType = [ "compressed-nu", "singleton" ]
}>
More information about the Mlir-commits
mailing list