[Mlir-commits] [mlir] 429919e - [mlir][sparse][pybind][CAPI] remove LevelType enum from CAPI, constru… (#81682)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 13 16:45:26 PST 2024


Author: Peiming Liu
Date: 2024-02-13T16:45:22-08:00
New Revision: 429919e32823ad735a19ab385f37e313512cedde

URL: https://github.com/llvm/llvm-project/commit/429919e32823ad735a19ab385f37e313512cedde
DIFF: https://github.com/llvm/llvm-project/commit/429919e32823ad735a19ab385f37e313512cedde.diff

LOG: [mlir][sparse][pybind][CAPI] remove LevelType enum from CAPI, constru… (#81682)

…ct LevelType from LevelFormat and properties instead.

**Rationale**
We used to explicitly declare every possible combination between
`LevelFormat` and `LevelProperties`, and it now becomes difficult to
scale as more properties/level formats are going to be introduced.

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
    mlir/lib/Bindings/Python/DialectSparseTensor.cpp
    mlir/lib/CAPI/Dialect/SparseTensor.cpp
    mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
    mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
    mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
    mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
    mlir/test/python/dialects/sparse_tensor/dialect.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index d549f5dddc1318..898d2f12779e39 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -27,23 +27,19 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// file.
 typedef uint64_t MlirSparseTensorLevelType;
 
-enum MlirBaseSparseTensorLevelType {
+enum MlirSparseTensorLevelFormat {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002,
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003,
   MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001,
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002,
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x000000040003,
   MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001,
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
   MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
 };
 
+enum MlirSparseTensorLevelPropertyNondefault {
+  MLIR_SPARSE_PROPERTY_NON_UNIQUE = 0x0001,
+  MLIR_SPARSE_PROPERTY_NON_ORDERED = 0x0002,
+};
+
 //===----------------------------------------------------------------------===//
 // SparseTensorEncodingAttr
 //===----------------------------------------------------------------------===//
@@ -66,6 +62,10 @@ mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
 mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
 
+/// Returns a specified level-format of the `sparse_tensor.encoding` attribute.
+MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelFormat
+mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl);
+
 /// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`
 /// attribute.
 MLIR_CAPI_EXPORTED MlirAffineMap
@@ -92,7 +92,9 @@ mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType);
 
 MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
 mlirSparseTensorEncodingAttrBuildLvlType(
-    enum MlirBaseSparseTensorLevelType lvlType, unsigned n, unsigned m);
+    enum MlirSparseTensorLevelFormat lvlFmt,
+    const enum MlirSparseTensorLevelPropertyNondefault *properties,
+    unsigned propSize, unsigned n, unsigned m);
 
 #ifdef __cplusplus
 }

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index e940d203be9ed5..74cc0dee554a17 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -35,6 +35,7 @@
 #include <cinttypes>
 #include <complex>
 #include <optional>
+#include <vector>
 
 namespace mlir {
 namespace sparse_tensor {
@@ -343,17 +344,31 @@ constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
 /// Convert a LevelFormat to its corresponding LevelType with the given
 /// properties. Returns std::nullopt when the properties are not applicable
 /// for the input level format.
-constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
-                                                  bool unique, uint64_t n = 0,
-                                                  uint64_t m = 0) {
+inline std::optional<LevelType>
+buildLevelType(LevelFormat lf,
+               const std::vector<LevelPropertyNondefault> &properties,
+               uint64_t n = 0, uint64_t m = 0) {
   uint64_t newN = n << 32;
   uint64_t newM = m << 40;
-  auto lt =
-      static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
-                             (unique ? 0 : 1) | newN | newM);
+  uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
+  for (auto p : properties) {
+    ltInt |= static_cast<uint64_t>(p);
+  }
+  auto lt = static_cast<LevelType>(ltInt);
   return isValidLT(lt) ? std::optional(lt) : std::nullopt;
 }
 
+inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
+                                               bool unique, uint64_t n = 0,
+                                               uint64_t m = 0) {
+  std::vector<LevelPropertyNondefault> properties;
+  if (!ordered)
+    properties.push_back(LevelPropertyNondefault::Nonordered);
+  if (!unique)
+    properties.push_back(LevelPropertyNondefault::Nonunique);
+  return buildLevelType(lf, properties, n, m);
+}
+
 //
 // Ensure the above methods work as intended.
 //
@@ -380,57 +395,6 @@ static_assert(
      *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
     "getLevelFormat conversion is broken");
 
-static_assert(
-    (buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
-     buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
-     buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::Dense, true, true) == LevelType::Dense &&
-     *buildLevelType(LevelFormat::Compressed, true, true) ==
-         LevelType::Compressed &&
-     *buildLevelType(LevelFormat::Compressed, true, false) ==
-         LevelType::CompressedNu &&
-     *buildLevelType(LevelFormat::Compressed, false, true) ==
-         LevelType::CompressedNo &&
-     *buildLevelType(LevelFormat::Compressed, false, false) ==
-         LevelType::CompressedNuNo &&
-     *buildLevelType(LevelFormat::Singleton, true, true) ==
-         LevelType::Singleton &&
-     *buildLevelType(LevelFormat::Singleton, true, false) ==
-         LevelType::SingletonNu &&
-     *buildLevelType(LevelFormat::Singleton, false, true) ==
-         LevelType::SingletonNo &&
-     *buildLevelType(LevelFormat::Singleton, false, false) ==
-         LevelType::SingletonNuNo &&
-     *buildLevelType(LevelFormat::LooseCompressed, true, true) ==
-         LevelType::LooseCompressed &&
-     *buildLevelType(LevelFormat::LooseCompressed, true, false) ==
-         LevelType::LooseCompressedNu &&
-     *buildLevelType(LevelFormat::LooseCompressed, false, true) ==
-         LevelType::LooseCompressedNo &&
-     *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
-         LevelType::LooseCompressedNuNo &&
-     buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
-     buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
-     buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
-    "buildLevelType conversion is broken");
-
-static_assert(
-    (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
-     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
-     getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
-     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
-    "getN/M conversion is broken");
-
-static_assert(
-    (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
-                      2, 4) &&
-     isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
-                      8, 10) &&
-     !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
-                       2, 4)),
-    "isValidNOutOfMLT definition is broken");
-
 static_assert(
     (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
      isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&

diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 74f4d2413a6993..171faf9e008746 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -23,24 +23,17 @@ using namespace mlir;
 using namespace mlir::python::adaptors;
 
 static void populateDialectSparseTensorSubmodule(const py::module &m) {
-  py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
+  py::enum_<MlirSparseTensorLevelFormat>(m, "LevelFormat", py::module_local())
       .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
       .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
       .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
-      .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
-      .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
-      .value("compressed_nu_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO)
       .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
-      .value("singleton_nu", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU)
-      .value("singleton_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO)
-      .value("singleton_nu_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO)
-      .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED)
-      .value("loose_compressed_nu",
-             MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU)
-      .value("loose_compressed_no",
-             MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO)
-      .value("loose_compressed_nu_no",
-             MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO);
+      .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED);
+
+  py::enum_<MlirSparseTensorLevelPropertyNondefault>(m, "LevelProperty",
+                                                     py::module_local())
+      .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED)
+      .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE);
 
   mlir_attribute_subclass(m, "EncodingAttr",
                           mlirAttributeIsASparseTensorEncodingAttr)
@@ -62,12 +55,17 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           "Gets a sparse_tensor.encoding from parameters.")
       .def_classmethod(
           "build_level_type",
-          [](py::object cls, MlirBaseSparseTensorLevelType lvlType, unsigned n,
-             unsigned m) {
-            return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, n, m);
+          [](py::object cls, MlirSparseTensorLevelFormat lvlFmt,
+             const std::vector<MlirSparseTensorLevelPropertyNondefault>
+                 &properties,
+             unsigned n, unsigned m) {
+            return mlirSparseTensorEncodingAttrBuildLvlType(
+                lvlFmt, properties.data(), properties.size(), n, m);
           },
-          py::arg("cls"), py::arg("lvl_type"), py::arg("n") = 0,
-          py::arg("m") = 0,
+          py::arg("cls"), py::arg("lvl_fmt"),
+          py::arg("properties") =
+              std::vector<MlirSparseTensorLevelPropertyNondefault>(),
+          py::arg("n") = 0, py::arg("m") = 0,
           "Builds a sparse_tensor.encoding.level_type from parameters.")
       .def_property_readonly(
           "lvl_types",
@@ -113,17 +111,12 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
             return mlirSparseTensorEncodingAttrGetStructuredM(
                 mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
           })
-      .def_property_readonly("lvl_types_enum", [](MlirAttribute self) {
+      .def_property_readonly("lvl_formats_enum", [](MlirAttribute self) {
         const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
-        std::vector<MlirBaseSparseTensorLevelType> ret;
+        std::vector<MlirSparseTensorLevelFormat> ret;
         ret.reserve(lvlRank);
-        for (int l = 0; l < lvlRank; l++) {
-          // Convert level type to 32 bits to ignore n and m for n_out_of_m
-          // format.
-          ret.push_back(
-              static_cast<MlirBaseSparseTensorLevelType>(static_cast<uint32_t>(
-                  mlirSparseTensorEncodingAttrGetLvlType(self, l))));
-        }
+        for (int l = 0; l < lvlRank; l++)
+          ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l));
         return ret;
       });
 }

diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 4e1bd45863fdac..55af8becbba20e 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -22,34 +22,23 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
 // Ensure the C-API enums are int-castable to C++ equivalents.
 static_assert(
     static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
-            static_cast<int>(LevelType::Dense) &&
+            static_cast<int>(LevelFormat::Dense) &&
         static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
-            static_cast<int>(LevelType::Compressed) &&
-        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
-            static_cast<int>(LevelType::CompressedNu) &&
-        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
-            static_cast<int>(LevelType::CompressedNo) &&
-        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
-            static_cast<int>(LevelType::CompressedNuNo) &&
+            static_cast<int>(LevelFormat::Compressed) &&
         static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
-            static_cast<int>(LevelType::Singleton) &&
-        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
-            static_cast<int>(LevelType::SingletonNu) &&
-        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
-            static_cast<int>(LevelType::SingletonNo) &&
-        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
-            static_cast<int>(LevelType::SingletonNuNo) &&
+            static_cast<int>(LevelFormat::Singleton) &&
         static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
-            static_cast<int>(LevelType::LooseCompressed) &&
-        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) ==
-            static_cast<int>(LevelType::LooseCompressedNu) &&
-        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) ==
-            static_cast<int>(LevelType::LooseCompressedNo) &&
-        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) ==
-            static_cast<int>(LevelType::LooseCompressedNuNo) &&
+            static_cast<int>(LevelFormat::LooseCompressed) &&
         static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
-            static_cast<int>(LevelType::NOutOfM),
-    "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
+            static_cast<int>(LevelFormat::NOutOfM),
+    "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");
+
+static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
+                      static_cast<int>(LevelPropertyNondefault::Nonordered) &&
+                  static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
+                      static_cast<int>(LevelPropertyNondefault::Nonunique),
+              "MlirSparseTensorLevelProperty (C-API) and "
+              "LevelPropertyNondefault (C++) mismatch");
 
 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));
@@ -87,6 +76,13 @@ mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
       cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
 }
 
+enum MlirSparseTensorLevelFormat
+mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
+  LevelType lt =
+      static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
+  return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt));
+}
+
 int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth();
 }
@@ -95,12 +91,17 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
 }
 
-MlirSparseTensorLevelType
-mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
-                                         unsigned n, unsigned m) {
-  LevelType lt = static_cast<LevelType>(lvlType);
-  return static_cast<MlirSparseTensorLevelType>(*buildLevelType(
-      *getLevelFormat(lt), isOrderedLT(lt), isUniqueLT(lt), n, m));
+MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
+    enum MlirSparseTensorLevelFormat lvlFmt,
+    const enum MlirSparseTensorLevelPropertyNondefault *properties,
+    unsigned size, unsigned n, unsigned m) {
+
+  std::vector<LevelPropertyNondefault> props;
+  for (unsigned i = 0; i < size; i++)
+    props.push_back(static_cast<LevelPropertyNondefault>(properties[i]));
+
+  return static_cast<MlirSparseTensorLevelType>(
+      *buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m));
 }
 
 unsigned

diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
index 199777c79ef838..e2050b98728f21 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -139,12 +139,15 @@ def main():
         # search the full state space to reduce runtime of the test. It is
         # straightforward to adapt the code below to explore more combinations.
         # For these simple orderings, dim2lvl and lvl2dim are the same.
+        builder = st.EncodingAttr.build_level_type
+        fmt = st.LevelFormat
+        prop = st.LevelProperty
         levels = [
-            [st.LevelType.compressed_nu, st.LevelType.singleton],
-            [st.LevelType.dense, st.LevelType.dense],
-            [st.LevelType.dense, st.LevelType.compressed],
-            [st.LevelType.compressed, st.LevelType.dense],
-            [st.LevelType.compressed, st.LevelType.compressed],
+            [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
+            [builder(fmt.dense), builder(fmt.dense)],
+            [builder(fmt.dense), builder(fmt.compressed)],
+            [builder(fmt.compressed), builder(fmt.dense)],
+            [builder(fmt.compressed), builder(fmt.compressed)],
         ]
         orderings = [
             ir.AffineMap.get_permutation([0, 1]),

diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
index 0aa4f92a7bf4ef..e7354c24d619e0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -125,12 +125,15 @@ def main():
         vl = 1
         e = False
         opt = f"parallelization-strategy=none"
+        builder = st.EncodingAttr.build_level_type
+        fmt = st.LevelFormat
+        prop = st.LevelProperty
         levels = [
-            [st.LevelType.compressed_nu, st.LevelType.singleton],
-            [st.LevelType.dense, st.LevelType.dense],
-            [st.LevelType.dense, st.LevelType.compressed],
-            [st.LevelType.compressed, st.LevelType.dense],
-            [st.LevelType.compressed, st.LevelType.compressed],
+            [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
+            [builder(fmt.dense), builder(fmt.dense)],
+            [builder(fmt.dense), builder(fmt.compressed)],
+            [builder(fmt.compressed), builder(fmt.dense)],
+            [builder(fmt.compressed), builder(fmt.compressed)],
         ]
         orderings = [
             ir.AffineMap.get_permutation([0, 1]),

diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index d994e8d0a8a19d..7da05303c7e1e1 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -124,11 +124,14 @@ def main():
         # Loop over various sparse types (COO, CSR, DCSR, CSC, DCSC) with
         # regular and loose compression and various metadata bitwidths.
         # For these simple orderings, dim2lvl and lvl2dim are the same.
+        builder = st.EncodingAttr.build_level_type
+        fmt = st.LevelFormat
+        prop = st.LevelProperty
         levels = [
-            [st.LevelType.compressed_nu, st.LevelType.singleton],
-            [st.LevelType.dense, st.LevelType.compressed],
-            [st.LevelType.dense, st.LevelType.loose_compressed],
-            [st.LevelType.compressed, st.LevelType.compressed],
+            [builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
+            [builder(fmt.dense), builder(fmt.compressed)],
+            [builder(fmt.dense), builder(fmt.loose_compressed)],
+            [builder(fmt.compressed), builder(fmt.compressed)],
         ]
         orderings = [
             (ir.AffineMap.get_permutation([0, 1]), 0),
@@ -149,10 +152,10 @@ def main():
 
         # Now do the same for BSR.
         level = [
-            st.LevelType.dense,
-            st.LevelType.compressed,
-            st.LevelType.dense,
-            st.LevelType.dense,
+            builder(fmt.dense),
+            builder(fmt.compressed),
+            builder(fmt.dense),
+            builder(fmt.dense),
         ]
         d0 = ir.AffineDimExpr.get(0)
         d1 = ir.AffineDimExpr.get(1)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index 2b79c1416562dc..ce3516e2edaf03 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -203,10 +203,10 @@ def main():
         shape = range(2, 3)
         rank = len(shape)
         # All combinations.
+        dense_lvl = st.EncodingAttr.build_level_type(st.LevelFormat.dense)
+        sparse_lvl = st.EncodingAttr.build_level_type(st.LevelFormat.compressed)
         levels = list(
-            itertools.product(
-                *itertools.repeat([st.LevelType.dense, st.LevelType.compressed], rank)
-            )
+            itertools.product(*itertools.repeat([dense_lvl, sparse_lvl], rank))
         )
         # All permutations.
         orderings = list(

diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 1fa7030ca1be91..2c0603216ef2c2 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -73,8 +73,8 @@ def testEncodingAttrStructure():
 
         # CHECK: lvl_types: [65536, 65536, 4406637494272]
         print(f"lvl_types: {casted.lvl_types}")
-        # CHECK: lvl_types_enum: [<LevelType.dense: 65536>, <LevelType.dense: 65536>, <LevelType.n_out_of_m: 1048576>]
-        print(f"lvl_types_enum: {casted.lvl_types_enum}")
+        # CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 1048576>]
+        print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
         # CHECK: structured_n: 2
         print(f"structured_n: {casted.structured_n}")
         # CHECK: structured_m: 4
@@ -96,7 +96,10 @@ def testEncodingAttrStructure():
         # CHECK: created_equal: False
         print(f"created_equal: {created == casted}")
 
-        built_2_4 = st.EncodingAttr.build_level_type(st.LevelType.n_out_of_m, 2, 4)
+        built_2_4 = st.EncodingAttr.build_level_type(
+            st.LevelFormat.n_out_of_m, [], 2, 4
+        )
+        built_dense = st.EncodingAttr.build_level_type(st.LevelFormat.dense)
         dim_to_lvl = AffineMap.get(
             2,
             0,
@@ -118,7 +121,7 @@ def testEncodingAttrStructure():
             ],
         )
         built = st.EncodingAttr.get(
-            [st.LevelType.dense, st.LevelType.dense, built_2_4],
+            [built_dense, built_dense, built_2_4],
             dim_to_lvl,
             lvl_to_dim,
             0,


        


More information about the Mlir-commits mailing list