[Mlir-commits] [mlir] [mlir][sparse] Populate lvlToDim (PR #68937)

Yinying Li llvmlistbot at llvm.org
Tue Oct 17 12:19:41 PDT 2023


https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/68937

>From d62f683f83cd3376102d16f87ba0202ed3c2bef4 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 12 Oct 2023 22:53:36 +0000
Subject: [PATCH 1/5] [mlir][sparse] Populate lvlToDim

Updates:
1. Infer lvlToDim from dimToLvl
2. Add more tests for block sparsity
3. Resolve the TODOs related to lvlToDim, including exposing lvlToDim in the Python bindings

Verification of a user-provided lvlToDim will be implemented in the next PR.
---
 mlir/include/mlir-c/Dialect/SparseTensor.h    |  3 +-
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  7 ++
 .../Bindings/Python/DialectSparseTensor.cpp   | 19 +++--
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        | 10 +--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 72 +++++++++++++++++--
 mlir/test/CAPI/sparse_tensor.c                |  5 +-
 .../SparseTensor/roundtrip_encoding.mlir      | 52 ++++++++++++++
 .../Dialect/SparseTensor/python/test_SDDMM.py |  2 +-
 .../Dialect/SparseTensor/python/test_SpMM.py  |  2 +-
 .../SparseTensor/python/test_output.py        |  3 +-
 .../SparseTensor/python/test_stress.py        |  2 +-
 .../python/dialects/sparse_tensor/dialect.py  | 10 ++-
 12 files changed, 164 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 7e47e54e7361d54..859a4f0dd9f52c8 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -51,11 +51,10 @@ MLIR_CAPI_EXPORTED bool
 mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 
 /// Creates a `sparse_tensor.encoding` attribute with the given parameters.
-/// TODO: add a version that supplied lvlToDim when it cannot be inferred
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
     MlirContext ctx, intptr_t lvlRank,
     enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
-    int posWidth, int crdWidth);
+    MlirAffineMap lvlTodim, int posWidth, int crdWidth);
 
 /// Returns the level-rank of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED intptr_t
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index cbca0a7f8cc0e3a..81cce3a2b807bed 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -160,6 +160,13 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
   return hasAnySparseOperand(op) || hasAnySparseResult(op);
 }
 
+//
+// Inference.
+//
+
+AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);
+AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);
+
 //
 // Reordering.
 //
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 8e9e0b6baf76c20..ba449faa9fe4262 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -41,17 +41,18 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
       .def_classmethod(
           "get",
           [](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
-             std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
+             std::optional<MlirAffineMap> dimToLvl,
+             std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
              MlirContext context) {
-            // TODO: provide dimToLvl
             return cls(mlirSparseTensorEncodingAttrGet(
                 context, lvlTypes.size(), lvlTypes.data(),
-                dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,
+                dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
+                lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
                 crdWidth));
           },
           py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
-          py::arg("pos_width"), py::arg("crd_width"),
-          py::arg("context") = py::none(),
+          py::arg("lvl_to_dim"), py::arg("pos_width"),
+          py::arg("crd_width"), py::arg("context") = py::none(),
           "Gets a sparse_tensor.encoding from parameters.")
       .def_property_readonly(
           "lvl_types",
@@ -71,6 +72,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
               return {};
             return ret;
           })
+      .def_property_readonly(
+          "lvl_to_dim",
+          [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+            MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
+            if (mlirAffineMapIsNull(ret))
+              return {};
+            return ret;
+          })
       .def_property_readonly("pos_width",
                              mlirSparseTensorEncodingAttrGetPosWidth)
       .def_property_readonly("crd_width",
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index bf3a4ad5e7a1683..309d5ff5fedb90e 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -48,15 +48,17 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
 MlirAttribute
 mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
                                 MlirSparseTensorDimLevelType const *lvlTypes,
-                                MlirAffineMap dimToLvl, int posWidth,
-                                int crdWidth) {
+                                MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
+                                int posWidth, int crdWidth) {
   SmallVector<DimLevelType> cppLvlTypes;
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
     cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
-  mlir::AffineMap lvlToDim; // TODO: provide in API
+  auto unwrappedLvlToDim = unwrap(lvlToDim);
+  if (!unwrappedLvlToDim)
+    unwrappedLvlToDim = inferLvlToDim(unwrap(dimToLvl), unwrap(ctx));
   return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
-                                            unwrap(dimToLvl), lvlToDim,
+                                            unwrap(dimToLvl), unwrappedLvlToDim,
                                             posWidth, crdWidth));
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index cd1e585438ddac9..22f20dc7b25101b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -293,9 +293,8 @@ Type SparseTensorEncodingAttr::getCrdType() const {
 SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  // TODO: infer lvlToDim
   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
-                                       /*lvlToDim*/ AffineMap(), getPosWidth(),
+                                       getLvlToDim(), getPosWidth(),
                                        getCrdWidth());
 }
 
@@ -583,7 +582,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #undef RETURN_ON_FAIL
 
   // Construct struct-like storage for attribute.
-  AffineMap lvlToDim; // TODO: infer
+  AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
   return parser.getChecked<SparseTensorEncodingAttr>(
       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
       dimSlices);
@@ -749,6 +748,71 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
   return nullptr;
 }
 
+AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
+                                             MLIRContext *context) {
+  auto map = static_cast<AffineMap>(dimToLvl);
+  AffineMap lvlToDim;
+  // TODO: support ELL instead of returning an empty lvlToDim.
+  if (!map || map.getNumSymbols() != 0) {
+    lvlToDim = AffineMap();
+  } else if (map.isPermutation()) {
+    lvlToDim = inversePermutation(map);
+  } else {
+    lvlToDim = inverseBlockSparsity(map, context);
+  }
+  return lvlToDim;
+}
+
+AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
+                                                    MLIRContext *context) {
+  SmallVector<AffineExpr> lvlExprs;
+  auto numLvls = dimToLvl.getNumResults();
+  lvlExprs.reserve(numLvls);
+  // lvlExprComponents stores information of the floordiv and mod operations
+  // applied to the same dimension, so as to build the lvlToDim map.
+  // Map key is the position of the dimension in dimToLvl.
+  // Map value is a SmallVector that contains lvl var for floordiv, multiplier,
+  // lvl var for mod in dimToLvl.
+  // For example, for il = i floordiv 2 and ii = i mod 2, the SmalleVector
+  // would be [il, 2, ii]. It could be used to build the AffineExpr
+  // i = il * 2 + ii in lvlToDim.
+  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
+  for (unsigned i = 0, n = numLvls; i < n; i++) {
+    auto result = dimToLvl.getResult(i);
+    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+      if (result.getKind() == AffineExprKind::FloorDiv) {
+        SmallVector<AffineExpr, 3> components;
+        // Level variable for floordiv.
+        components.push_back(getAffineDimExpr(i, context));
+        // Multiplier.
+        components.push_back(binOp.getRHS());
+        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        lvlExprComponents[pos] = components;
+      } else if (result.getKind() == AffineExprKind::Mod) {
+        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
+               "expected floordiv before mod");
+        // Level variable for mod.
+        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
+      } else {
+        assert(false && "expected floordiv or mod");
+      }
+    } else {
+      lvlExprs.push_back(getAffineDimExpr(i, context));
+    }
+  }
+  for (auto &components : lvlExprComponents) {
+    assert(components.second.size() == 3 &&
+           "expected 3 components to build lvlExprs");
+    auto mulOp = getAffineBinaryOpExpr(
+        AffineExprKind::Mul, components.second[0], components.second[1]);
+    auto addOp =
+        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
+    lvlExprs.push_back(addOp);
+  }
+  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
+}
+
 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                     Level startLvl, bool isUnique) {
   if (!enc ||
@@ -811,7 +875,7 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
   // default value.
   unsigned posWidth = src.getPosWidth();
   unsigned crdWidth = src.getCrdWidth();
-  AffineMap invPerm; // TODO
+  auto invPerm = src.getLvlToDim();
   auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
                                            invPerm, posWidth, crdWidth);
   return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 33ee8e784096a18..3bd1508cf299a3d 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -40,6 +40,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: level_type: 4
   // CHECK: level_type: 8
   // CHECK: level_type: 8
+  MlirAffineMap lvlToDim =
+      mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
   enum MlirSparseTensorDimLevelType *lvlTypes =
       malloc(sizeof(enum MlirSparseTensorDimLevelType) * lvlRank);
@@ -53,9 +55,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: crdWidth: 64
   int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
   fprintf(stderr, "crdWidth: %d\n", crdWidth);
-  // TODO: lvlToDim
   MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
-      ctx, lvlRank, lvlTypes, dimToLvl, posWidth, crdWidth);
+      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth);
   mlirAttributeDump(newAttr); // For debugging filecheck output.
   // CHECK: equal: 1
   fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index ae3805d8b774176..ea8217ab6e3f233 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -160,6 +160,24 @@ func.func private @BSR(%arg0: tensor<?x?xf64, #BSR>) {
 
 // -----
 
+#BCSR = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : dense,
+    k floordiv 4 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense,
+    k mod 4      : dense
+  )
+}>
+
+// CHECK-LABEL: func private @BCSR(
+// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 floordiv 2 : dense, d1 floordiv 3 : dense, d2 floordiv 4 : compressed, d0 mod 2 : dense, d1 mod 3 : dense, d2 mod 4 : dense) }>>
+func.func private @BCSR(%arg0: tensor<?x?x?xf64, #BCSR>) {
+  return
+}
+// -----
+
 #BSR_explicit = #sparse_tensor.encoding<{
   map =
   {il, jl, ii, jj}
@@ -194,3 +212,37 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
 func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
   return
 }
+
+// -----
+
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    j            : dense,
+    k floordiv 4 : dense,
+    k mod 4      : block2_4
+  )
+}>
+
+// CHECK-LABEL: func private @NV_24(
+// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>>
+func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
+  return
+}
+
+// -----
+
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    k floordiv 4 : dense,
+    j            : dense,
+    k mod 4      : block2_4
+  )
+}>
+
+// CHECK-LABEL: func private @NV_24(
+// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>>
+func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
index 0cdc7c88bd97fb8..1f9b6360383180c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -155,7 +155,7 @@ def main():
                     for iwidth in [32]:
                         for e in [True]:
                             attr = st.EncodingAttr.get(
-                                level, ordering, pwidth, iwidth
+                                level, ordering, None, pwidth, iwidth
                             )
                             opt = f"parallelization-strategy=none"
                             compiler = sparse_compiler.SparseCompiler(
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
index 01d74a4dc82fa1d..69f6cdcea967fae 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -145,7 +145,7 @@ def main():
                 for pwidth in bitwidths:
                     for iwidth in bitwidths:
                         attr = st.EncodingAttr.get(
-                            level, ordering, pwidth, iwidth
+                            level, ordering, None, pwidth, iwidth
                         )
                         build_compile_and_run_SpMM(attr, compiler)
                         count = count + 1
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 8f3f4e5af1e58ef..b8ef614e04dfad2 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -91,7 +91,8 @@ def main():
         for level in levels:
             for ordering in orderings:
                 for bwidth in bitwidths:
-                    attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
+                    attr = st.EncodingAttr.get(level, ordering, None,
+                                               bwidth, bwidth)
                     build_compile_and_run_output(attr, compiler)
                     count = count + 1
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index ef266672ce42afc..841b02bc10c8bec 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -227,7 +227,7 @@ def main():
                 for pwidth in bitwidths:
                     for iwidth in bitwidths:
                         attr = st.EncodingAttr.get(
-                            level, ordering, pwidth, iwidth
+                            level, ordering, None, pwidth, iwidth
                         )
                         types.append(ir.RankedTensorType.get(shape, f64, attr))
         #
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index d80b878323377a4..10b390273cc91f3 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -32,12 +32,14 @@ def testEncodingAttr1D():
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: None
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
+        # CHECK: lvl_to_dim: None
+        print(f"lvl_to_dim: {casted.lvl_to_dim}")
         # CHECK: pos_width: 16
         print(f"pos_width: {casted.pos_width}")
         # CHECK: crd_width: 32
         print(f"crd_width: {casted.crd_width}")
 
-        created = st.EncodingAttr.get(casted.lvl_types, None, 0, 0)
+        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
         # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
         print(created)
         # CHECK: created_equal: False
@@ -72,12 +74,16 @@ def testEncodingAttr2D():
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
+        # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
+        print(f"lvl_to_dim: {casted.lvl_to_dim}")
         # CHECK: pos_width: 8
         print(f"pos_width: {casted.pos_width}")
         # CHECK: crd_width: 32
         print(f"crd_width: {casted.crd_width}")
 
-        created = st.EncodingAttr.get(casted.lvl_types, casted.dim_to_lvl, 8, 32)
+        created = st.EncodingAttr.get(
+            casted.lvl_types, casted.dim_to_lvl, casted.lvl_to_dim, 8, 32,
+        )
         # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
         print(created)
         # CHECK: created_equal: True

>From a957b09f24dcd527718fd1c62af63e229829860e Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 12 Oct 2023 23:34:00 +0000
Subject: [PATCH 2/5] format

---
 mlir/lib/Bindings/Python/DialectSparseTensor.cpp            | 4 ++--
 .../Integration/Dialect/SparseTensor/python/test_output.py  | 3 +--
 mlir/test/python/dialects/sparse_tensor/dialect.py          | 6 +++++-
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index ba449faa9fe4262..9bde3a443ecfeca 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -51,8 +51,8 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                 crdWidth));
           },
           py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
-          py::arg("lvl_to_dim"), py::arg("pos_width"),
-          py::arg("crd_width"), py::arg("context") = py::none(),
+          py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
+          py::arg("context") = py::none(),
           "Gets a sparse_tensor.encoding from parameters.")
       .def_property_readonly(
           "lvl_types",
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index b8ef614e04dfad2..7d7749008020515 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -91,8 +91,7 @@ def main():
         for level in levels:
             for ordering in orderings:
                 for bwidth in bitwidths:
-                    attr = st.EncodingAttr.get(level, ordering, None,
-                                               bwidth, bwidth)
+                    attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
                     build_compile_and_run_output(attr, compiler)
                     count = count + 1
 
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 10b390273cc91f3..240db6ebd1d1eb3 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -82,7 +82,11 @@ def testEncodingAttr2D():
         print(f"crd_width: {casted.crd_width}")
 
         created = st.EncodingAttr.get(
-            casted.lvl_types, casted.dim_to_lvl, casted.lvl_to_dim, 8, 32,
+            casted.lvl_types,
+            casted.dim_to_lvl,
+            casted.lvl_to_dim,
+            8,
+            32,
         )
         # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
         print(created)

>From f2efff38910184f6bab297776b67576d035404ce Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Mon, 16 Oct 2023 19:54:33 +0000
Subject: [PATCH 3/5] address comments

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  3 +++
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        |  5 +----
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 22 +++++++++++--------
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 38c7200afb41ffc..47fd18a689d5a8d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -307,6 +307,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
                      "AffineMap":$lvlToDim,
                      "unsigned":$posWidth,
                      "unsigned":$crdWidth), [{
+      if (!lvlToDim) {
+        lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
+      }
       return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
         ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
     }]>
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 309d5ff5fedb90e..c3ad95527df489f 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -54,11 +54,8 @@ mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
     cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
-  auto unwrappedLvlToDim = unwrap(lvlToDim);
-  if (!unwrappedLvlToDim)
-    unwrappedLvlToDim = inferLvlToDim(unwrap(dimToLvl), unwrap(ctx));
   return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
-                                            unwrap(dimToLvl), unwrappedLvlToDim,
+                                            unwrap(dimToLvl), unwrap(lvlToDim),
                                             posWidth, crdWidth));
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 22f20dc7b25101b..1d07e18e53a14ae 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -582,6 +582,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #undef RETURN_ON_FAIL
 
   // Construct struct-like storage for attribute.
+  // TODO: Fetch lvlToDim if user provides one
   AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
   return parser.getChecked<SparseTensorEncodingAttr>(
       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
@@ -770,29 +771,28 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
   lvlExprs.reserve(numLvls);
   // lvlExprComponents stores information of the floordiv and mod operations
   // applied to the same dimension, so as to build the lvlToDim map.
-  // Map key is the position of the dimension in dimToLvl.
-  // Map value is a SmallVector that contains lvl var for floordiv, multiplier,
-  // lvl var for mod in dimToLvl.
-  // For example, for il = i floordiv 2 and ii = i mod 2, the SmalleVector
-  // would be [il, 2, ii]. It could be used to build the AffineExpr
-  // i = il * 2 + ii in lvlToDim.
   std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
   for (unsigned i = 0, n = numLvls; i < n; i++) {
     auto result = dimToLvl.getResult(i);
     if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
       if (result.getKind() == AffineExprKind::FloorDiv) {
+        // Position of the dimension in dimToLvl.
+        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
+               "expected only one floordiv for each dimension");
         SmallVector<AffineExpr, 3> components;
         // Level variable for floordiv.
         components.push_back(getAffineDimExpr(i, context));
         // Multiplier.
         components.push_back(binOp.getRHS());
-        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        // Map key is the position of the dimension.
         lvlExprComponents[pos] = components;
       } else if (result.getKind() == AffineExprKind::Mod) {
         auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
         assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
                "expected floordiv before mod");
-        // Level variable for mod.
+        // Level variable for mod added to the vector of the corresponding
+        // floordiv with the same dimension.
         lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
       } else {
         assert(false && "expected floordiv or mod");
@@ -801,6 +801,10 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
       lvlExprs.push_back(getAffineDimExpr(i, context));
     }
   }
+  // Build lvlExprs from lvlExprComponents.
+  // For example, for il = i floordiv 2 and ii = i mod 2, the components
+  // would be [il, 2, ii]. It could be used to build the AffineExpr
+  // i = il * 2 + ii in lvlToDim.
   for (auto &components : lvlExprComponents) {
     assert(components.second.size() == 3 &&
            "expected 3 components to build lvlExprs");
@@ -875,7 +879,7 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
   // default value.
   unsigned posWidth = src.getPosWidth();
   unsigned crdWidth = src.getCrdWidth();
-  auto invPerm = src.getLvlToDim();
+  AffineMap invPerm = src.getLvlToDim();
   auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
                                            invPerm, posWidth, crdWidth);
   return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);

>From 8d28731aebd02e3abe713d4c188248ce13051e83 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Mon, 16 Oct 2023 21:12:53 +0000
Subject: [PATCH 4/5] modify comments

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1d07e18e53a14ae..1495c4e23b0965d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -753,7 +753,7 @@ AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                              MLIRContext *context) {
   auto map = static_cast<AffineMap>(dimToLvl);
   AffineMap lvlToDim;
-  // TODO: support ELL instead of returning an empty lvlToDim.
+  // Return an empty lvlToDim when inference is not successful.
   if (!map || map.getNumSymbols() != 0) {
     lvlToDim = AffineMap();
   } else if (map.isPermutation()) {
@@ -791,8 +791,8 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
         auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
         assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
                "expected floordiv before mod");
-        // Level variable for mod added to the vector of the corresponding
-        // floordiv with the same dimension.
+        // Add level variable for mod to the same vector
+        // of the corresponding floordiv.
         lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
       } else {
         assert(false && "expected floordiv or mod");

>From 808d87861d02769cce8755df87de02c40221adfa Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 17 Oct 2023 18:58:15 +0000
Subject: [PATCH 5/5] added docs

---
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h | 6 ++++++
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 1 +
 2 files changed, 7 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 81cce3a2b807bed..6e834426b441764 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -164,7 +164,13 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
 // Inference.
 //
 
+/// Given the dimToLvl map, infers the lvlToDim map, or returns an
+/// empty AffineMap when inference fails.
 AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);
+
+/// Returns the lvlToDim map for the given dimToLvl map, handling only
+/// the block-sparse case (floordiv/mod level expressions).
+/// Asserts on failure (so only use when known to succeed).
 AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);
 
 //
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1495c4e23b0965d..fd87bbfa905ed69 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -759,6 +759,7 @@ AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
   } else if (map.isPermutation()) {
     lvlToDim = inversePermutation(map);
   } else {
+    // TODO: check if it's block sparsity
     lvlToDim = inverseBlockSparsity(map, context);
   }
   return lvlToDim;



More information about the Mlir-commits mailing list