[Mlir-commits] [mlir] 429919e - [mlir][sparse][pybind][CAPI] remove LevelType enum from CAPI, construct LevelType from LevelFormat and properties instead (#81682)
llvmlistbot at llvm.org
Tue Feb 13 16:45:26 PST 2024
Author: Peiming Liu
Date: 2024-02-13T16:45:22-08:00
New Revision: 429919e32823ad735a19ab385f37e313512cedde
URL: https://github.com/llvm/llvm-project/commit/429919e32823ad735a19ab385f37e313512cedde
DIFF: https://github.com/llvm/llvm-project/commit/429919e32823ad735a19ab385f37e313512cedde.diff
LOG: [mlir][sparse][pybind][CAPI] remove LevelType enum from CAPI, construct LevelType from LevelFormat and properties instead. (#81682)
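**Rationale**
We used to explicitly declare every possible combination of
`LevelFormat` and `LevelProperties` as a separate enumerator, which no
longer scales as more properties and level formats are introduced.
This change instead exposes the level formats and the non-default level
properties as two separate enums, and builds a level type by combining a
format with a list of properties.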
Added:
Modified:
mlir/include/mlir-c/Dialect/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
mlir/lib/Bindings/Python/DialectSparseTensor.cpp
mlir/lib/CAPI/Dialect/SparseTensor.cpp
mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
mlir/test/python/dialects/sparse_tensor/dialect.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index d549f5dddc1318..898d2f12779e39 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -27,23 +27,19 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
/// file.
typedef uint64_t MlirSparseTensorLevelType;
-enum MlirBaseSparseTensorLevelType {
+/// Sparse tensor level formats. The numeric values are required to match the
+/// C++ `mlir::sparse_tensor::LevelFormat` enumerators.
+enum MlirSparseTensorLevelFormat {
MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002,
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003,
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001,
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002,
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x000000040003,
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001,
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
};
+/// Non-default sparse tensor level properties. Each value is a single bit that
+/// is OR-ed into the level type produced by
+/// mlirSparseTensorEncodingAttrBuildLvlType.
+enum MlirSparseTensorLevelPropertyNondefault {
+ MLIR_SPARSE_PROPERTY_NON_UNIQUE = 0x0001,
+ MLIR_SPARSE_PROPERTY_NON_ORDERED = 0x0002,
+};
+
//===----------------------------------------------------------------------===//
// SparseTensorEncodingAttr
//===----------------------------------------------------------------------===//
@@ -66,6 +62,10 @@ mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
+/// Returns the level format of level `lvl` of the `sparse_tensor.encoding`
+/// attribute. Unlike mlirSparseTensorEncodingAttrGetLvlType, the result is a
+/// bare format: it carries no property bits and no n/m fields.
+MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelFormat
+mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl);
+
/// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`
/// attribute.
MLIR_CAPI_EXPORTED MlirAffineMap
@@ -92,7 +92,9 @@ mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType);
MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
mlirSparseTensorEncodingAttrBuildLvlType(
- enum MlirBaseSparseTensorLevelType lvlType, unsigned n, unsigned m);
+ enum MlirSparseTensorLevelFormat lvlFmt,
+ const enum MlirSparseTensorLevelPropertyNondefault *properties,
+ unsigned propSize, unsigned n, unsigned m);
#ifdef __cplusplus
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index e940d203be9ed5..74cc0dee554a17 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -35,6 +35,7 @@
#include <cinttypes>
#include <complex>
#include <optional>
+#include <vector>
namespace mlir {
namespace sparse_tensor {
@@ -343,17 +344,31 @@ constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
/// Convert a LevelFormat to its corresponding LevelType with the given
/// properties. Returns std::nullopt when the properties are not applicable
/// for the input level format.
-constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
- bool unique, uint64_t n = 0,
- uint64_t m = 0) {
+/// Each entry of `properties` is OR-ed into the resulting level type.
+inline std::optional<LevelType>
+buildLevelType(LevelFormat lf,
+ const std::vector<LevelPropertyNondefault> &properties,
+ uint64_t n = 0, uint64_t m = 0) {
+  // `n` is packed starting at bit 32 and `m` starting at bit 40.
uint64_t newN = n << 32;
uint64_t newM = m << 40;
- auto lt =
- static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
- (unique ? 0 : 1) | newN | newM);
+ uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
+  // Non-default properties are plain bits; set each requested one.
+ for (auto p : properties) {
+ ltInt |= static_cast<uint64_t>(p);
+ }
+ auto lt = static_cast<LevelType>(ltInt);
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
}
+/// Convenience overload taking the two currently supported properties as
+/// booleans (`ordered == false` sets Nonordered, `unique == false` sets
+/// Nonunique). Returns std::nullopt for combinations that are not valid for
+/// `lf`. Kept `constexpr` and allocation-free so it remains usable in
+/// constant expressions (e.g. static_assert checks), as before this change.
+constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
+                                                  bool unique, uint64_t n = 0,
+                                                  uint64_t m = 0) {
+  uint64_t ltInt = static_cast<uint64_t>(lf) | (n << 32) | (m << 40);
+  if (!ordered)
+    ltInt |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
+  if (!unique)
+    ltInt |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
+  auto lt = static_cast<LevelType>(ltInt);
+  return isValidLT(lt) ? std::optional(lt) : std::nullopt;
+}
+
//
// Ensure the above methods work as intended.
//
@@ -380,57 +395,6 @@ static_assert(
*getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
"getLevelFormat conversion is broken");
-static_assert(
- (buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::Dense, true, true) == LevelType::Dense &&
- *buildLevelType(LevelFormat::Compressed, true, true) ==
- LevelType::Compressed &&
- *buildLevelType(LevelFormat::Compressed, true, false) ==
- LevelType::CompressedNu &&
- *buildLevelType(LevelFormat::Compressed, false, true) ==
- LevelType::CompressedNo &&
- *buildLevelType(LevelFormat::Compressed, false, false) ==
- LevelType::CompressedNuNo &&
- *buildLevelType(LevelFormat::Singleton, true, true) ==
- LevelType::Singleton &&
- *buildLevelType(LevelFormat::Singleton, true, false) ==
- LevelType::SingletonNu &&
- *buildLevelType(LevelFormat::Singleton, false, true) ==
- LevelType::SingletonNo &&
- *buildLevelType(LevelFormat::Singleton, false, false) ==
- LevelType::SingletonNuNo &&
- *buildLevelType(LevelFormat::LooseCompressed, true, true) ==
- LevelType::LooseCompressed &&
- *buildLevelType(LevelFormat::LooseCompressed, true, false) ==
- LevelType::LooseCompressedNu &&
- *buildLevelType(LevelFormat::LooseCompressed, false, true) ==
- LevelType::LooseCompressedNo &&
- *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
- LevelType::LooseCompressedNuNo &&
- buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
- "buildLevelType conversion is broken");
-
-static_assert(
- (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
- getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
- getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
- getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
- "getN/M conversion is broken");
-
-static_assert(
- (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
- 2, 4) &&
- isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
- 8, 10) &&
- !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
- 2, 4)),
- "isValidNOutOfMLT definition is broken");
-
static_assert(
(isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 74f4d2413a6993..171faf9e008746 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -23,24 +23,17 @@ using namespace mlir;
using namespace mlir::python::adaptors;
static void populateDialectSparseTensorSubmodule(const py::module &m) {
- py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
+ py::enum_<MlirSparseTensorLevelFormat>(m, "LevelFormat", py::module_local())
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
.value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
.value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
- .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
- .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
- .value("compressed_nu_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO)
.value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
- .value("singleton_nu", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU)
- .value("singleton_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO)
- .value("singleton_nu_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO)
- .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED)
- .value("loose_compressed_nu",
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU)
- .value("loose_compressed_no",
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO)
- .value("loose_compressed_nu_no",
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO);
+ .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED);
+
+ py::enum_<MlirSparseTensorLevelPropertyNondefault>(m, "LevelProperty",
+ py::module_local())
+ .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED)
+ .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE);
mlir_attribute_subclass(m, "EncodingAttr",
mlirAttributeIsASparseTensorEncodingAttr)
@@ -62,12 +55,17 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
"Gets a sparse_tensor.encoding from parameters.")
.def_classmethod(
"build_level_type",
- [](py::object cls, MlirBaseSparseTensorLevelType lvlType, unsigned n,
- unsigned m) {
- return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, n, m);
+ [](py::object cls, MlirSparseTensorLevelFormat lvlFmt,
+ const std::vector<MlirSparseTensorLevelPropertyNondefault>
+ &properties,
+ unsigned n, unsigned m) {
+ return mlirSparseTensorEncodingAttrBuildLvlType(
+ lvlFmt, properties.data(), properties.size(), n, m);
},
- py::arg("cls"), py::arg("lvl_type"), py::arg("n") = 0,
- py::arg("m") = 0,
+ py::arg("cls"), py::arg("lvl_fmt"),
+ py::arg("properties") =
+ std::vector<MlirSparseTensorLevelPropertyNondefault>(),
+ py::arg("n") = 0, py::arg("m") = 0,
"Builds a sparse_tensor.encoding.level_type from parameters.")
.def_property_readonly(
"lvl_types",
@@ -113,17 +111,12 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
return mlirSparseTensorEncodingAttrGetStructuredM(
mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
})
- .def_property_readonly("lvl_types_enum", [](MlirAttribute self) {
+ .def_property_readonly("lvl_formats_enum", [](MlirAttribute self) {
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
- std::vector<MlirBaseSparseTensorLevelType> ret;
+ std::vector<MlirSparseTensorLevelFormat> ret;
ret.reserve(lvlRank);
- for (int l = 0; l < lvlRank; l++) {
- // Convert level type to 32 bits to ignore n and m for n_out_of_m
- // format.
- ret.push_back(
- static_cast<MlirBaseSparseTensorLevelType>(static_cast<uint32_t>(
- mlirSparseTensorEncodingAttrGetLvlType(self, l))));
- }
+ for (int l = 0; l < lvlRank; l++)
+ ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l));
return ret;
});
}
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 4e1bd45863fdac..55af8becbba20e 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -22,34 +22,23 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
// Ensure the C-API enums are int-castable to C++ equivalents.
static_assert(
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
- static_cast<int>(LevelType::Dense) &&
+ static_cast<int>(LevelFormat::Dense) &&
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
- static_cast<int>(LevelType::Compressed) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
- static_cast<int>(LevelType::CompressedNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
- static_cast<int>(LevelType::CompressedNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
- static_cast<int>(LevelType::CompressedNuNo) &&
+ static_cast<int>(LevelFormat::Compressed) &&
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
- static_cast<int>(LevelType::Singleton) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
- static_cast<int>(LevelType::SingletonNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
- static_cast<int>(LevelType::SingletonNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
- static_cast<int>(LevelType::SingletonNuNo) &&
+ static_cast<int>(LevelFormat::Singleton) &&
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
- static_cast<int>(LevelType::LooseCompressed) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) ==
- static_cast<int>(LevelType::LooseCompressedNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) ==
- static_cast<int>(LevelType::LooseCompressedNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) ==
- static_cast<int>(LevelType::LooseCompressedNuNo) &&
+ static_cast<int>(LevelFormat::LooseCompressed) &&
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
- static_cast<int>(LevelType::NOutOfM),
- "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
+ static_cast<int>(LevelFormat::NOutOfM),
+ "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");
+
+// The C property values are cast and OR-ed directly into C++ LevelType bits,
+// so the two enums must agree exactly.
+static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
+                      static_cast<int>(LevelPropertyNondefault::Nonordered) &&
+                  static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
+                      static_cast<int>(LevelPropertyNondefault::Nonunique),
+              "MlirSparseTensorLevelPropertyNondefault (C-API) and "
+              "LevelPropertyNondefault (C++) mismatch");
bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
return isa<SparseTensorEncodingAttr>(unwrap(attr));
@@ -87,6 +76,13 @@ mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
}
+// Returns only the format part of level `lvl`'s level type, i.e. without the
+// property bits and the n/m fields included by
+// mlirSparseTensorEncodingAttrGetLvlType.
+enum MlirSparseTensorLevelFormat
+mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
+ LevelType lt =
+ static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
+  // NOTE(review): the optional returned by getLevelFormat is dereferenced
+  // without a check; this assumes every level type stored in an encoding
+  // attribute has a format (undefined behavior otherwise) — confirm.
+ return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt));
+}
+
int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth();
}
@@ -95,12 +91,17 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
}
-MlirSparseTensorLevelType
-mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
- unsigned n, unsigned m) {
- LevelType lt = static_cast<LevelType>(lvlType);
- return static_cast<MlirSparseTensorLevelType>(*buildLevelType(
- *getLevelFormat(lt), isOrderedLT(lt), isUniqueLT(lt), n, m));
+/// Builds a level type from `lvlFmt`, the `propSize` non-default properties in
+/// `properties`, and the n/m parameters (used by n_out_of_m levels).
+/// Returns the value of LevelType::Undef when the format/property combination
+/// is invalid, instead of dereferencing an empty std::optional.
+MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
+    enum MlirSparseTensorLevelFormat lvlFmt,
+    const enum MlirSparseTensorLevelPropertyNondefault *properties,
+    unsigned propSize, unsigned n, unsigned m) {
+  // `properties` may be null when `propSize` is 0 (e.g. data() of an empty
+  // std::vector coming from the Python bindings); it is never read then.
+  std::vector<LevelPropertyNondefault> props;
+  props.reserve(propSize);
+  for (unsigned i = 0; i < propSize; i++)
+    props.push_back(static_cast<LevelPropertyNondefault>(properties[i]));
+
+  std::optional<LevelType> lt =
+      buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m);
+  // Invalid combinations (e.g. a non-unique dense level) yield std::nullopt;
+  // dereferencing that would be undefined behavior.
+  if (!lt)
+    return static_cast<MlirSparseTensorLevelType>(LevelType::Undef);
+  return static_cast<MlirSparseTensorLevelType>(*lt);
+}
unsigned
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
index 199777c79ef838..e2050b98728f21 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -139,12 +139,15 @@ def main():
# search the full state space to reduce runtime of the test. It is
# straightforward to adapt the code below to explore more combinations.
# For these simple orderings, dim2lvl and lvl2dim are the same.
+ builder = st.EncodingAttr.build_level_type
+ fmt = st.LevelFormat
+ prop = st.LevelProperty
levels = [
- [st.LevelType.compressed_nu, st.LevelType.singleton],
- [st.LevelType.dense, st.LevelType.dense],
- [st.LevelType.dense, st.LevelType.compressed],
- [st.LevelType.compressed, st.LevelType.dense],
- [st.LevelType.compressed, st.LevelType.compressed],
+ [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
+ [builder(fmt.dense), builder(fmt.dense)],
+ [builder(fmt.dense), builder(fmt.compressed)],
+ [builder(fmt.compressed), builder(fmt.dense)],
+ [builder(fmt.compressed), builder(fmt.compressed)],
]
orderings = [
ir.AffineMap.get_permutation([0, 1]),
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
index 0aa4f92a7bf4ef..e7354c24d619e0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -125,12 +125,15 @@ def main():
vl = 1
e = False
opt = f"parallelization-strategy=none"
+ builder = st.EncodingAttr.build_level_type
+ fmt = st.LevelFormat
+ prop = st.LevelProperty
levels = [
- [st.LevelType.compressed_nu, st.LevelType.singleton],
- [st.LevelType.dense, st.LevelType.dense],
- [st.LevelType.dense, st.LevelType.compressed],
- [st.LevelType.compressed, st.LevelType.dense],
- [st.LevelType.compressed, st.LevelType.compressed],
+ [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
+ [builder(fmt.dense), builder(fmt.dense)],
+ [builder(fmt.dense), builder(fmt.compressed)],
+ [builder(fmt.compressed), builder(fmt.dense)],
+ [builder(fmt.compressed), builder(fmt.compressed)],
]
orderings = [
ir.AffineMap.get_permutation([0, 1]),
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index d994e8d0a8a19d..7da05303c7e1e1 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -124,11 +124,14 @@ def main():
# Loop over various sparse types (COO, CSR, DCSR, CSC, DCSC) with
# regular and loose compression and various metadata bitwidths.
# For these simple orderings, dim2lvl and lvl2dim are the same.
+ builder = st.EncodingAttr.build_level_type
+ fmt = st.LevelFormat
+ prop = st.LevelProperty
levels = [
- [st.LevelType.compressed_nu, st.LevelType.singleton],
- [st.LevelType.dense, st.LevelType.compressed],
- [st.LevelType.dense, st.LevelType.loose_compressed],
- [st.LevelType.compressed, st.LevelType.compressed],
+ [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
+ [builder(fmt.dense), builder(fmt.compressed)],
+ [builder(fmt.dense), builder(fmt.loose_compressed)],
+ [builder(fmt.compressed), builder(fmt.compressed)],
]
orderings = [
(ir.AffineMap.get_permutation([0, 1]), 0),
@@ -149,10 +152,10 @@ def main():
# Now do the same for BSR.
level = [
- st.LevelType.dense,
- st.LevelType.compressed,
- st.LevelType.dense,
- st.LevelType.dense,
+ builder(fmt.dense),
+ builder(fmt.compressed),
+ builder(fmt.dense),
+ builder(fmt.dense),
]
d0 = ir.AffineDimExpr.get(0)
d1 = ir.AffineDimExpr.get(1)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index 2b79c1416562dc..ce3516e2edaf03 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -203,10 +203,10 @@ def main():
shape = range(2, 3)
rank = len(shape)
# All combinations.
+ dense_lvl = st.EncodingAttr.build_level_type(st.LevelFormat.dense)
+ sparse_lvl = st.EncodingAttr.build_level_type(st.LevelFormat.compressed)
levels = list(
- itertools.product(
- *itertools.repeat([st.LevelType.dense, st.LevelType.compressed], rank)
- )
+ itertools.product(*itertools.repeat([dense_lvl, sparse_lvl], rank))
)
# All permutations.
orderings = list(
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 1fa7030ca1be91..2c0603216ef2c2 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -73,8 +73,8 @@ def testEncodingAttrStructure():
# CHECK: lvl_types: [65536, 65536, 4406637494272]
print(f"lvl_types: {casted.lvl_types}")
- # CHECK: lvl_types_enum: [<LevelType.dense: 65536>, <LevelType.dense: 65536>, <LevelType.n_out_of_m: 1048576>]
- print(f"lvl_types_enum: {casted.lvl_types_enum}")
+ # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 1048576>]
+ print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
# CHECK: structured_n: 2
print(f"structured_n: {casted.structured_n}")
# CHECK: structured_m: 4
@@ -96,7 +96,10 @@ def testEncodingAttrStructure():
# CHECK: created_equal: False
print(f"created_equal: {created == casted}")
- built_2_4 = st.EncodingAttr.build_level_type(st.LevelType.n_out_of_m, 2, 4)
+ built_2_4 = st.EncodingAttr.build_level_type(
+ st.LevelFormat.n_out_of_m, [], 2, 4
+ )
+ built_dense = st.EncodingAttr.build_level_type(st.LevelFormat.dense)
dim_to_lvl = AffineMap.get(
2,
0,
@@ -118,7 +121,7 @@ def testEncodingAttrStructure():
],
)
built = st.EncodingAttr.get(
- [st.LevelType.dense, st.LevelType.dense, built_2_4],
+ [built_dense, built_dense, built_2_4],
dim_to_lvl,
lvl_to_dim,
0,
More information about the Mlir-commits
mailing list