[Mlir-commits] [mlir] [mlir][sparse] Add more tests and verification for n:m (PR #81186)

Yinying Li llvmlistbot at llvm.org
Thu Feb 8 12:14:15 PST 2024


https://github.com/yinying-lisa-li created https://github.com/llvm/llvm-project/pull/81186

1. Add python test for n out of m
2. Add more methods for python binding
3. Add verification for n:m and invalid encoding tests

>From 61bc16760fb362b42051000346f05b7a1475ee6f Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Wed, 7 Feb 2024 20:10:32 +0000
Subject: [PATCH 1/3] pybinding for n:m

---
 mlir/include/mlir-c/Dialect/SparseTensor.h    | 11 ++++
 .../Bindings/Python/DialectSparseTensor.cpp   | 39 ++++++++++++-
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        | 18 ++++++
 mlir/test/CAPI/sparse_tensor.c                |  4 +-
 .../python/dialects/sparse_tensor/dialect.py  | 57 +++++++++++++++++++
 5 files changed, 126 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 2c71b0008ad16a..d4cac0e326cffd 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -84,6 +84,17 @@ mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);
 MLIR_CAPI_EXPORTED int
 mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr);
 
+MLIR_CAPI_EXPORTED unsigned
+mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType);
+
+MLIR_CAPI_EXPORTED unsigned
+mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType);
+
+MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
+mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
+                                         bool ordered, bool unique, unsigned n,
+                                         unsigned m);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 607534c6156439..ac2b6720c699b9 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -60,6 +60,16 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
           py::arg("context") = py::none(),
           "Gets a sparse_tensor.encoding from parameters.")
+      .def_classmethod(
+          "build_level_type",
+          [](py::object cls, MlirBaseSparseTensorLevelType lvlType,
+             bool ordered, bool unique, unsigned n, unsigned m) {
+            return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, ordered,
+                                                            unique, n, m);
+          },
+          py::arg("cls"), py::arg("lvl_type"), py::arg("ordered") = true,
+          py::arg("unique") = true, py::arg("n") = 0, py::arg("m") = 0,
+          "Builds a sparse_tensor.encoding.level_type from parameters.")
       .def_property_readonly(
           "lvl_types",
           [](MlirAttribute self) {
@@ -89,7 +99,34 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
       .def_property_readonly("pos_width",
                              mlirSparseTensorEncodingAttrGetPosWidth)
       .def_property_readonly("crd_width",
-                             mlirSparseTensorEncodingAttrGetCrdWidth);
+                             mlirSparseTensorEncodingAttrGetCrdWidth)
+      .def_property_readonly(
+          "structured_n",
+          [](MlirAttribute self) -> unsigned {
+            const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+            return mlirSparseTensorEncodingAttrGetStructuredN(
+                mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
+          })
+      .def_property_readonly(
+          "structured_m",
+          [](MlirAttribute self) -> unsigned {
+            const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+            return mlirSparseTensorEncodingAttrGetStructuredM(
+                mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
+          })
+      .def_property_readonly("lvl_types_enum", [](MlirAttribute self) {
+        const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
+        std::vector<MlirBaseSparseTensorLevelType> ret;
+        ret.reserve(lvlRank);
+        for (int l = 0; l < lvlRank; l++) {
+          // Convert level type to 32 bits to ignore n and m for n_out_of_m
+          // format.
+          ret.push_back(
+              static_cast<MlirBaseSparseTensorLevelType>(static_cast<uint32_t>(
+                  mlirSparseTensorEncodingAttrGetLvlType(self, l))));
+        }
+        return ret;
+      });
 }
 
 PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index a34b9a29b0e90a..2a7d47d9ece90f 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -94,3 +94,21 @@ int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
 int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
 }
+
+MlirSparseTensorLevelType
+mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
+                                         bool ordered, bool unique, unsigned n,
+                                         unsigned m) {
+  return static_cast<MlirSparseTensorLevelType>(*buildLevelType(
+      *getLevelFormat(static_cast<LevelType>(lvlType)), ordered, unique, n, m));
+}
+
+unsigned
+mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) {
+  return getN(static_cast<LevelType>(lvlType));
+}
+
+unsigned
+mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) {
+  return getM(static_cast<LevelType>(lvlType));
+}
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index a8b9f9048d5912..5c21f0b1d8927d 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -26,7 +26,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // clang-format off
   const char *originalAsm =
     "#sparse_tensor.encoding<{ "
-    "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), "
+    "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : structured[2, 4]), "
     "posWidth = 32, crdWidth = 64 }>";
   // clang-format on
   MlirAttribute originalAttr =
@@ -40,7 +40,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
   mlirAffineMapDump(dimToLvl);
   // CHECK: level_type: 65536
   // CHECK: level_type: 131072
-  // CHECK: level_type: 131072
+  // CHECK: level_type: 4406637494272
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 412c5797067b7a..ba7205d5162aa1 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -52,6 +52,63 @@ def testEncodingAttr1D():
         print(f"created_pos_width: {created.pos_width}")
 
 
+# CHECK-LABEL: TEST: testEncodingAttr1DStructure
+@run
+def testEncodingAttr1DStructure():
+    with Context() as ctx:
+        parsed = Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            "  map = (d0) -> (d0 : structured[2, 4]),"
+            "  posWidth = 16,"
+            "  crdWidth = 32"
+            "}>"
+        )
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
+        print(parsed)
+
+        casted = st.EncodingAttr(parsed)
+        # CHECK: equal: True
+        print(f"equal: {casted == parsed}")
+
+        # CHECK: lvl_types: [4406637494272]
+        print(f"lvl_types: {casted.lvl_types}")
+        # CHECK: lvl_types: [<LevelType.n_out_of_m: 1048576>]
+        print(f"lvl_types_enum: {casted.lvl_types_enum}")
+        # CHECK: structured_n: 2
+        print(f"structured_n: {casted.structured_n}")
+        # CHECK: structured_m: 4
+        print(f"structured_m: {casted.structured_m}")
+        # CHECK: dim_to_lvl: (d0) -> (d0)
+        print(f"dim_to_lvl: {casted.dim_to_lvl}")
+        # CHECK: lvl_to_dim: (d0) -> (d0)
+        print(f"lvl_to_dim: {casted.lvl_to_dim}")
+        # CHECK: pos_width: 16
+        print(f"pos_width: {casted.pos_width}")
+        # CHECK: crd_width: 32
+        print(f"crd_width: {casted.crd_width}")
+
+        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]) }>
+        print(created)
+        # CHECK: created_equal: False
+        print(f"created_equal: {created == casted}")
+
+        built_type = st.EncodingAttr.build_level_type(
+            st.LevelType.n_out_of_m, True, True, 2, 4
+        )
+        built = st.EncodingAttr.get([built_type], None, None, 0, 0)
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]) }>
+        print(built)
+        # CHECK: built_equal: True
+        print(f"built_equal: {built == created}")
+
+        # Verify that the factory creates an instance of the proper type.
+        # CHECK: is_proper_instance: True
+        print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+        # CHECK: created_pos_width: 0
+        print(f"created_pos_width: {created.pos_width}")
+
+
 # CHECK-LABEL: TEST: testEncodingAttr2D
 @run
 def testEncodingAttr2D():

>From 00b2bf5febd20f5b891335cedb95073527f7ad9d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Wed, 7 Feb 2024 22:42:43 +0000
Subject: [PATCH 2/3] verification

---
 mlir/include/mlir-c/Dialect/SparseTensor.h    |  5 +-
 .../Bindings/Python/DialectSparseTensor.cpp   |  8 +--
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        |  6 +-
 .../SparseTensor/IR/Detail/LvlTypeParser.cpp  | 28 +++++++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 10 ++++
 mlir/test/CAPI/sparse_tensor.c                |  4 +-
 .../SparseTensor/invalid_encoding.mlir        | 60 +++++++++++++++++++
 .../python/dialects/sparse_tensor/dialect.py  | 57 +++++++++++++-----
 8 files changed, 146 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index d4cac0e326cffd..d549f5dddc1318 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -91,9 +91,8 @@ MLIR_CAPI_EXPORTED unsigned
 mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType);
 
 MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
-mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
-                                         bool ordered, bool unique, unsigned n,
-                                         unsigned m);
+mlirSparseTensorEncodingAttrBuildLvlType(
+    enum MlirBaseSparseTensorLevelType lvlType, unsigned n, unsigned m);
 
 #ifdef __cplusplus
 }
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index ac2b6720c699b9..fe6e5e4e5b2b26 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -62,13 +62,13 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           "Gets a sparse_tensor.encoding from parameters.")
       .def_classmethod(
           "build_level_type",
-          [](py::object cls, MlirBaseSparseTensorLevelType lvlType,
-             bool ordered, bool unique, unsigned n, unsigned m) {
+          [](py::object cls, MlirBaseSparseTensorLevelType lvlType, unsigned n,
+             unsigned m) {
             return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, ordered,
                                                             unique, n, m);
           },
-          py::arg("cls"), py::arg("lvl_type"), py::arg("ordered") = true,
-          py::arg("unique") = true, py::arg("n") = 0, py::arg("m") = 0,
+          py::arg("cls"), py::arg("lvl_type"), py::arg("n") = 0,
+          py::arg("m") = 0,
           "Builds a sparse_tensor.encoding.level_type from parameters.")
       .def_property_readonly(
           "lvl_types",
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 2a7d47d9ece90f..4e1bd45863fdac 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -97,10 +97,10 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
 
 MlirSparseTensorLevelType
 mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
-                                         bool ordered, bool unique, unsigned n,
-                                         unsigned m) {
+                                         unsigned n, unsigned m) {
+  LevelType lt = static_cast<LevelType>(lvlType);
   return static_cast<MlirSparseTensorLevelType>(*buildLevelType(
-      *getLevelFormat(static_cast<LevelType>(lvlType)), ordered, unique, n, m));
+      *getLevelFormat(lt), isOrderedLT(lt), isUniqueLT(lt), n, m));
 }
 
 unsigned
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 752d6e6481dfee..49189a82fb0704 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -45,6 +45,22 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
     FAILURE_IF_FAILED(res)
   }
 
+  if (base.compare("structured") == 0) {
+    ParseResult res = parser.parseCommaSeparatedList(
+        mlir::OpAsmParser::Delimiter::OptionalSquare,
+        [&]() -> ParseResult { return parseStructure(parser, &structure); },
+        " in structure n out of m");
+    FAILURE_IF_FAILED(res)
+    if (structure.size() != 2) {
+      parser.emitError(loc, "expected exactly 2 structure sizes");
+      return failure();
+    }
+    if (structure[0] > structure[1]) {
+      parser.emitError(loc, "expected n <= m in n_out_of_m");
+      return failure();
+    }
+  }
+
   ParseResult res = parser.parseCommaSeparatedList(
       mlir::OpAsmParser::Delimiter::OptionalParen,
       [&]() -> ParseResult { return parseProperty(parser, &properties); },
@@ -57,10 +73,6 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   } else if (base.compare("compressed") == 0) {
     properties |= static_cast<uint64_t>(LevelFormat::Compressed);
   } else if (base.compare("structured") == 0) {
-    if (structure.size() != 2) {
-      parser.emitError(loc, "expected exactly 2 structure sizes");
-      return failure();
-    }
     properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
     properties |= nToBits(structure[0]) | mToBits(structure[1]);
   } else if (base.compare("loose_compressed") == 0) {
@@ -102,13 +114,17 @@ LvlTypeParser::parseStructure(AsmParser &parser,
   OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
   if (intValParseResult.has_value()) {
     if (failed(*intValParseResult)) {
-      parser.emitError(loc, "failed to parse block size");
+      parser.emitError(loc, "failed to parse structure size");
+      return failure();
+    }
+    if (intVal < 0) {
+      parser.emitError(loc, "expected structure size to be >= 0");
       return failure();
     }
     structure->push_back(intVal);
     return success();
   }
-  parser.emitError(loc, "expected valid integer for block size");
+  parser.emitError(loc, "expected valid integer for structure size");
   return failure();
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 67b1d7974fa259..c7a063f020794e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -657,6 +657,16 @@ LogicalResult SparseTensorEncodingAttr::verify(
       return emitError() << "expected all singleton lvlTypes "
                             "following a singleton level";
   }
+  // TODO: audit formats that actually are supported by backend.
+  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
+      it != std::end(lvlTypes)) {
+    if (it != lvlTypes.end() - 1)
+      return emitError() << "expected n_out_of_m to be the last level type";
+    if (!std::all_of(lvlTypes.begin(), it,
+                     [](LevelType i) { return isDenseLT(i); }))
+      return emitError() << "expected all dense lvlTypes "
+                            "before a n_out_of_m level";
+  }
   // Before we can check that the level-rank is consistent/coherent
   // across all fields, we need to define it.  The source-of-truth for
   // the `getLvlRank` method is the length of the level-types array,
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 5c21f0b1d8927d..580a619d1f7d99 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -26,7 +26,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // clang-format off
   const char *originalAsm =
     "#sparse_tensor.encoding<{ "
-    "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : structured[2, 4]), "
+    "map = [s0](d0, d1) -> (s0 : dense, d0 : dense, d1 : structured[2, 4]), "
     "posWidth = 32, crdWidth = 64 }>";
   // clang-format on
   MlirAttribute originalAttr =
@@ -39,7 +39,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
   mlirAffineMapDump(dimToLvl);
   // CHECK: level_type: 65536
-  // CHECK: level_type: 131072
+  // CHECK: level_type: 65536
   // CHECK: level_type: 4406637494272
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 2d189cc94c15e2..49d6e7a6d866a9 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -315,3 +315,63 @@ func.func private @BSR(%arg0: tensor<?x?xf64, #BSR>) {
 func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
   return
 }
+
+// -----
+
+// expected-error@+6 {{expected structure size to be >= 0}}
+#NOutOfM = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    k floordiv 4 : dense,
+    j            : dense,
+    k mod 4      : structured[-2, 4]
+  )
+}>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+  return
+}
+
+// -----
+
+// expected-error@+6 {{expected n <= m in n_out_of_m}}
+#NOutOfM = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    k floordiv 4 : dense,
+    j            : dense,
+    k mod 4      : structured[5, 4]
+  )
+}>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+  return
+}
+
+// -----
+
+// expected-error@+1 {{expected all dense lvlTypes before a n_out_of_m level}}
+#NOutOfM = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    k floordiv 4 : compressed,
+    j            : dense,
+    k mod 4      : structured[2, 4]
+  )
+}>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+  return
+}
+
+// -----
+
+// expected-error@+1 {{expected n_out_of_m to be the last level type}}
+#NOutOfM = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    k floordiv 4 : structured[2, 4],
+    j            : dense,
+    k mod 4      : compressed
+  )
+}>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+  return
+}
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index ba7205d5162aa1..b066ba00dddae5 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -52,52 +52,81 @@ def testEncodingAttr1D():
         print(f"created_pos_width: {created.pos_width}")
 
 
-# CHECK-LABEL: TEST: testEncodingAttr1DStructure
+# CHECK-LABEL: TEST: testEncodingAttrStructure
 @run
-def testEncodingAttr1DStructure():
+def testEncodingAttrStructure():
     with Context() as ctx:
         parsed = Attribute.parse(
             "#sparse_tensor.encoding<{"
-            "  map = (d0) -> (d0 : structured[2, 4]),"
+            "  map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,"
+            "  d1 mod 4 : structured[2, 4]),"
             "  posWidth = 16,"
             "  crdWidth = 32"
             "}>"
         )
-        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
+        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
         print(parsed)
 
         casted = st.EncodingAttr(parsed)
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [4406637494272]
+        # CHECK: lvl_types: [65536, 65536, 4406637494272]
         print(f"lvl_types: {casted.lvl_types}")
-        # CHECK: lvl_types: [<LevelType.n_out_of_m: 1048576>]
+        # CHECK: lvl_types_enum: [<LevelType.dense: 65536>, <LevelType.dense: 65536>, <LevelType.n_out_of_m: 1048576>]
         print(f"lvl_types_enum: {casted.lvl_types_enum}")
         # CHECK: structured_n: 2
         print(f"structured_n: {casted.structured_n}")
         # CHECK: structured_m: 4
         print(f"structured_m: {casted.structured_m}")
-        # CHECK: dim_to_lvl: (d0) -> (d0)
+        # CHECK: dim_to_lvl: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
-        # CHECK: lvl_to_dim: (d0) -> (d0)
+        # CHECK: lvl_to_dim: (d0, d1, d2) -> (d0, d1 * 4 + d2)
         print(f"lvl_to_dim: {casted.lvl_to_dim}")
         # CHECK: pos_width: 16
         print(f"pos_width: {casted.pos_width}")
         # CHECK: crd_width: 32
         print(f"crd_width: {casted.crd_width}")
 
-        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
-        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]) }>
+        created = st.EncodingAttr.get(
+            casted.lvl_types, casted.dim_to_lvl, casted.lvl_to_dim, 0, 0
+        )
+        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
         print(created)
         # CHECK: created_equal: False
         print(f"created_equal: {created == casted}")
 
-        built_type = st.EncodingAttr.build_level_type(
-            st.LevelType.n_out_of_m, True, True, 2, 4
+        built_2_4 = st.EncodingAttr.build_level_type(
+            st.LevelType.n_out_of_m, 2, 4
+        )
+        dim_to_lvl = AffineMap.get(
+            2,
+            0,
+            [
+                AffineExpr.get_dim(0),
+                AffineExpr.get_floor_div(AffineExpr.get_dim(1), 4),
+                AffineExpr.get_mod(AffineExpr.get_dim(1), 4),
+            ],
+        )
+        lvl_to_dim = AffineMap.get(
+            3,
+            0,
+            [
+                AffineExpr.get_dim(0),
+                AffineExpr.get_add(
+                    AffineExpr.get_mul(AffineExpr.get_dim(1), 4),
+                    AffineExpr.get_dim(2),
+                ),
+            ],
+        )
+        built = st.EncodingAttr.get(
+            [st.LevelType.dense, st.LevelType.dense, built_2_4],
+            dim_to_lvl,
+            lvl_to_dim,
+            0,
+            0,
         )
-        built = st.EncodingAttr.get([built_type], None, None, 0, 0)
-        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : structured[2, 4]) }>
+        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]) }>
         print(built)
         # CHECK: built_equal: True
         print(f"built_equal: {built == created}")

>From bab45ddc09e96982dc27f2edd0262812e3398c17 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 8 Feb 2024 20:11:39 +0000
Subject: [PATCH 3/3] remove redundant code

---
 mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 49189a82fb0704..a585928c3fa3ee 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -37,14 +37,6 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   uint64_t properties = 0;
   SmallVector<unsigned> structure;
 
-  if (base.compare("structured") == 0) {
-    ParseResult res = parser.parseCommaSeparatedList(
-        mlir::OpAsmParser::Delimiter::OptionalSquare,
-        [&]() -> ParseResult { return parseStructure(parser, &structure); },
-        " in block n out of m");
-    FAILURE_IF_FAILED(res)
-  }
-
   if (base.compare("structured") == 0) {
     ParseResult res = parser.parseCommaSeparatedList(
         mlir::OpAsmParser::Delimiter::OptionalSquare,



More information about the Mlir-commits mailing list