[Mlir-commits] [mlir] [mlir][sparse] Populate lvlToDim (PR #68937)

Yinying Li llvmlistbot at llvm.org
Mon Oct 16 14:10:22 PDT 2023


https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/68937

>From 52c5001c775bcb44042124d1afcb1db09fb59957 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 12 Oct 2023 22:53:36 +0000
Subject: [PATCH 1/3] [mlir][sparse] Populate lvlToDim

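Updates:
1. Infer lvlToDim from dimToLvl.
2. Add more tests for block sparsity.
3. Finish the TODOs related to lvlToDim, including exposing lvlToDim in the
   Python binding.

Note that this changes the signatures of the C API
`mlirSparseTensorEncodingAttrGet` and the Python `EncodingAttr.get`: both now
take an explicit lvlToDim argument, which may be null/None to have it inferred.

Verification of a user-provided lvlToDim will be implemented in the next PR.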
---
 mlir/include/mlir-c/Dialect/SparseTensor.h    |  3 +-
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  7 ++
 .../Bindings/Python/DialectSparseTensor.cpp   | 19 +++--
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        | 10 +--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 72 +++++++++++++++++--
 mlir/test/CAPI/sparse_tensor.c                |  5 +-
 .../SparseTensor/roundtrip_encoding.mlir      | 52 ++++++++++++++
 .../Dialect/SparseTensor/python/test_SDDMM.py |  2 +-
 .../Dialect/SparseTensor/python/test_SpMM.py  |  2 +-
 .../SparseTensor/python/test_output.py        |  3 +-
 .../SparseTensor/python/test_stress.py        |  2 +-
 .../python/dialects/sparse_tensor/dialect.py  | 10 ++-
 12 files changed, 164 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 7e47e54e7361d54..859a4f0dd9f52c8 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -51,11 +51,11 @@ MLIR_CAPI_EXPORTED bool
 mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 
 /// Creates a `sparse_tensor.encoding` attribute with the given parameters.
-/// TODO: add a version that supplied lvlToDim when it cannot be inferred
+/// `lvlToDim` may be null; it is then inferred from `dimToLvl` when possible.
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
     MlirContext ctx, intptr_t lvlRank,
     enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
-    int posWidth, int crdWidth);
+    MlirAffineMap lvlToDim, int posWidth, int crdWidth);
 
 /// Returns the level-rank of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED intptr_t
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 3eb9ce010cb006f..8cbedc560089f7d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -159,6 +159,13 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
   return hasAnySparseOperand(op) || hasAnySparseResult(op);
 }
 
+//
+// Inference.
+//
+
+AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);
+AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);
+
 //
 // Reordering.
 //
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 8e9e0b6baf76c20..ba449faa9fe4262 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -41,17 +41,18 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
       .def_classmethod(
           "get",
           [](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
-             std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
+             std::optional<MlirAffineMap> dimToLvl,
+             std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
              MlirContext context) {
-            // TODO: provide dimToLvl
             return cls(mlirSparseTensorEncodingAttrGet(
                 context, lvlTypes.size(), lvlTypes.data(),
-                dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,
+                dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
+                lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
                 crdWidth));
           },
           py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
-          py::arg("pos_width"), py::arg("crd_width"),
-          py::arg("context") = py::none(),
+          py::arg("lvl_to_dim"), py::arg("pos_width"),
+          py::arg("crd_width"), py::arg("context") = py::none(),
           "Gets a sparse_tensor.encoding from parameters.")
       .def_property_readonly(
           "lvl_types",
@@ -71,6 +72,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
               return {};
             return ret;
           })
+      .def_property_readonly(
+          "lvl_to_dim",
+          [](MlirAttribute self) -> std::optional<MlirAffineMap> {
+            MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
+            if (mlirAffineMapIsNull(ret))
+              return {};
+            return ret;
+          })
       .def_property_readonly("pos_width",
                              mlirSparseTensorEncodingAttrGetPosWidth)
       .def_property_readonly("crd_width",
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index bf3a4ad5e7a1683..309d5ff5fedb90e 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -48,15 +48,17 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
 MlirAttribute
 mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
                                 MlirSparseTensorDimLevelType const *lvlTypes,
-                                MlirAffineMap dimToLvl, int posWidth,
-                                int crdWidth) {
+                                MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
+                                int posWidth, int crdWidth) {
   SmallVector<DimLevelType> cppLvlTypes;
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
     cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
-  mlir::AffineMap lvlToDim; // TODO: provide in API
+  auto unwrappedLvlToDim = unwrap(lvlToDim);
+  if (!unwrappedLvlToDim)
+    unwrappedLvlToDim = inferLvlToDim(unwrap(dimToLvl), unwrap(ctx));
   return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
-                                            unwrap(dimToLvl), lvlToDim,
+                                            unwrap(dimToLvl), unwrappedLvlToDim,
                                             posWidth, crdWidth));
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 61522fb0dcd24b5..212de502640ef7b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -293,9 +293,10 @@ Type SparseTensorEncodingAttr::getCrdType() const {
 SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  // TODO: infer lvlToDim
+  // getLvlToDim() belongs to the old dimToLvl; re-infer it for the new one.
   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
-                                       /*lvlToDim*/ AffineMap(), getPosWidth(),
+                                       inferLvlToDim(dimToLvl, getContext()),
+                                       getPosWidth(),
                                        getCrdWidth());
 }
 
@@ -583,7 +582,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #undef RETURN_ON_FAIL
 
   // Construct struct-like storage for attribute.
-  AffineMap lvlToDim; // TODO: infer
+  AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
   return parser.getChecked<SparseTensorEncodingAttr>(
       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
       dimSlices);
@@ -749,6 +748,88 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
   return nullptr;
 }
 
+AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
+                                             MLIRContext *context) {
+  // TODO: support ELL instead of returning an empty lvlToDim.
+  if (!dimToLvl || dimToLvl.getNumSymbols() != 0)
+    return AffineMap();
+  if (dimToLvl.isPermutation())
+    return inversePermutation(dimToLvl);
+  // inverseBlockSparsity only handles results that are plain dims or
+  // `dim floordiv c` / `dim mod c`. Callers (e.g. the parser) may run this
+  // before the encoding is verified, so return an empty map for anything
+  // else (e.g. `d0 + d1`) instead of reaching its assertions.
+  for (AffineExpr expr : dimToLvl.getResults()) {
+    auto binOp = expr.dyn_cast<AffineBinaryOpExpr>();
+    bool isBlock = binOp && binOp.getLHS().isa<AffineDimExpr>() &&
+                   (expr.getKind() == AffineExprKind::FloorDiv ||
+                    expr.getKind() == AffineExprKind::Mod);
+    if (!isBlock && !expr.isa<AffineDimExpr>())
+      return AffineMap();
+  }
+  return inverseBlockSparsity(dimToLvl, context);
+}
+
+AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
+                                                    MLIRContext *context) {
+  SmallVector<AffineExpr> lvlExprs;
+  // Result `d` of lvlToDim must define dimension `d`, so the final results are
+  // collected by dimension position rather than in level order.
+  SmallVector<AffineExpr> dimExprs(dimToLvl.getNumDims());
+  auto numLvls = dimToLvl.getNumResults();
+  lvlExprs.reserve(numLvls);
+  // lvlExprComponents stores information of the floordiv and mod operations
+  // applied to the same dimension, so as to build the lvlToDim map.
+  // Map key is the position of the dimension in dimToLvl.
+  // Map value is a SmallVector that contains lvl var for floordiv, multiplier,
+  // lvl var for mod in dimToLvl.
+  // For example, for il = i floordiv 2 and ii = i mod 2, the SmalleVector
+  // would be [il, 2, ii]. It could be used to build the AffineExpr
+  // i = il * 2 + ii in lvlToDim.
+  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
+  for (unsigned i = 0, n = numLvls; i < n; i++) {
+    auto result = dimToLvl.getResult(i);
+    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+      if (result.getKind() == AffineExprKind::FloorDiv) {
+        SmallVector<AffineExpr, 3> components;
+        // Level variable for floordiv.
+        components.push_back(getAffineDimExpr(i, context));
+        // Multiplier.
+        components.push_back(binOp.getRHS());
+        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        lvlExprComponents[pos] = components;
+      } else if (result.getKind() == AffineExprKind::Mod) {
+        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
+               "expected floordiv before mod");
+        // Level variable for mod.
+        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
+      } else {
+        assert(false && "expected floordiv or mod");
+      }
+    } else {
+      lvlExprs.push_back(getAffineDimExpr(i, context));
+    }
+  }
+  for (auto &components : lvlExprComponents) {
+    assert(components.second.size() == 3 &&
+           "expected 3 components to build lvlExprs");
+    auto mulOp = getAffineBinaryOpExpr(
+        AffineExprKind::Mul, components.second[0], components.second[1]);
+    dimExprs[components.first] =
+        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
+  }
+  // lvlExprs holds the plain (unblocked) dims in level order; move each one to
+  // the slot of the dimension it reads.
+  for (AffineExpr lvlVar : lvlExprs)
+    dimExprs[dimToLvl.getDimPosition(
+        lvlVar.cast<AffineDimExpr>().getPosition())] = lvlVar;
+  // A dimension that no level reads cannot be recovered.
+  if (llvm::is_contained(dimExprs, AffineExpr()))
+    return AffineMap();
+  return AffineMap::get(numLvls, 0, dimExprs, context);
+}
+
 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                     Level startLvl, bool isUnique) {
   if (!enc ||
@@ -811,7 +875,7 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
   // default value.
   unsigned posWidth = src.getPosWidth();
   unsigned crdWidth = src.getCrdWidth();
-  AffineMap invPerm; // TODO
+  auto invPerm = src.getLvlToDim();
   auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
                                            invPerm, posWidth, crdWidth);
   return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 33ee8e784096a18..3bd1508cf299a3d 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -40,6 +40,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: level_type: 4
   // CHECK: level_type: 8
   // CHECK: level_type: 8
+  MlirAffineMap lvlToDim =
+      mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
   enum MlirSparseTensorDimLevelType *lvlTypes =
       malloc(sizeof(enum MlirSparseTensorDimLevelType) * lvlRank);
@@ -53,9 +55,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: crdWidth: 64
   int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
   fprintf(stderr, "crdWidth: %d\n", crdWidth);
-  // TODO: lvlToDim
   MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
-      ctx, lvlRank, lvlTypes, dimToLvl, posWidth, crdWidth);
+      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth);
   mlirAttributeDump(newAttr); // For debugging filecheck output.
   // CHECK: equal: 1
   fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index ae3805d8b774176..ea8217ab6e3f233 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -160,6 +160,24 @@ func.func private @BSR(%arg0: tensor<?x?xf64, #BSR>) {
 
 // -----
 
+#BCSR = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : dense,
+    k floordiv 4 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense,
+    k mod 4      : dense
+  )
+}>
+
+// CHECK-LABEL: func private @BCSR(
+// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 floordiv 2 : dense, d1 floordiv 3 : dense, d2 floordiv 4 : compressed, d0 mod 2 : dense, d1 mod 3 : dense, d2 mod 4 : dense) }>>
+func.func private @BCSR(%arg0: tensor<?x?x?xf64, #BCSR>) {
+  return
+}
+// -----
+
 #BSR_explicit = #sparse_tensor.encoding<{
   map =
   {il, jl, ii, jj}
@@ -194,3 +212,37 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
 func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
   return
 }
+
+// -----
+
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    j            : dense,
+    k floordiv 4 : dense,
+    k mod 4      : block2_4
+  )
+}>
+
+// CHECK-LABEL: func private @NV_24(
+// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>>
+func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
+  return
+}
+
+// -----
+
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    k floordiv 4 : dense,
+    j            : dense,
+    k mod 4      : block2_4
+  )
+}>
+
+// CHECK-LABEL: func private @NV_24(
+// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>>
+func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
index 0cdc7c88bd97fb8..1f9b6360383180c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -155,7 +155,7 @@ def main():
                     for iwidth in [32]:
                         for e in [True]:
                             attr = st.EncodingAttr.get(
-                                level, ordering, pwidth, iwidth
+                                level, ordering, None, pwidth, iwidth
                             )
                             opt = f"parallelization-strategy=none"
                             compiler = sparse_compiler.SparseCompiler(
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
index 01d74a4dc82fa1d..69f6cdcea967fae 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -145,7 +145,7 @@ def main():
                 for pwidth in bitwidths:
                     for iwidth in bitwidths:
                         attr = st.EncodingAttr.get(
-                            level, ordering, pwidth, iwidth
+                            level, ordering, None, pwidth, iwidth
                         )
                         build_compile_and_run_SpMM(attr, compiler)
                         count = count + 1
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index 8f3f4e5af1e58ef..b8ef614e04dfad2 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -91,7 +91,8 @@ def main():
         for level in levels:
             for ordering in orderings:
                 for bwidth in bitwidths:
-                    attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
+                    attr = st.EncodingAttr.get(level, ordering, None,
+                                               bwidth, bwidth)
                     build_compile_and_run_output(attr, compiler)
                     count = count + 1
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index ef266672ce42afc..841b02bc10c8bec 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -227,7 +227,7 @@ def main():
                 for pwidth in bitwidths:
                     for iwidth in bitwidths:
                         attr = st.EncodingAttr.get(
-                            level, ordering, pwidth, iwidth
+                            level, ordering, None, pwidth, iwidth
                         )
                         types.append(ir.RankedTensorType.get(shape, f64, attr))
         #
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index d80b878323377a4..10b390273cc91f3 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -32,12 +32,14 @@ def testEncodingAttr1D():
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: None
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
+        # CHECK: lvl_to_dim: None
+        print(f"lvl_to_dim: {casted.lvl_to_dim}")
         # CHECK: pos_width: 16
         print(f"pos_width: {casted.pos_width}")
         # CHECK: crd_width: 32
         print(f"crd_width: {casted.crd_width}")
 
-        created = st.EncodingAttr.get(casted.lvl_types, None, 0, 0)
+        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
         # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
         print(created)
         # CHECK: created_equal: False
@@ -72,12 +74,16 @@ def testEncodingAttr2D():
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
+        # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
+        print(f"lvl_to_dim: {casted.lvl_to_dim}")
         # CHECK: pos_width: 8
         print(f"pos_width: {casted.pos_width}")
         # CHECK: crd_width: 32
         print(f"crd_width: {casted.crd_width}")
 
-        created = st.EncodingAttr.get(casted.lvl_types, casted.dim_to_lvl, 8, 32)
+        created = st.EncodingAttr.get(
+            casted.lvl_types, casted.dim_to_lvl, casted.lvl_to_dim, 8, 32,
+        )
         # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
         print(created)
         # CHECK: created_equal: True

>From ec4cbb0176cb9752bc306d43640f03e1700754ec Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 12 Oct 2023 23:34:00 +0000
Subject: [PATCH 2/3] format

---
 mlir/lib/Bindings/Python/DialectSparseTensor.cpp            | 4 ++--
 .../Integration/Dialect/SparseTensor/python/test_output.py  | 3 +--
 mlir/test/python/dialects/sparse_tensor/dialect.py          | 6 +++++-
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index ba449faa9fe4262..9bde3a443ecfeca 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -51,8 +51,8 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                 crdWidth));
           },
           py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
-          py::arg("lvl_to_dim"), py::arg("pos_width"),
-          py::arg("crd_width"), py::arg("context") = py::none(),
+          py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
+          py::arg("context") = py::none(),
           "Gets a sparse_tensor.encoding from parameters.")
       .def_property_readonly(
           "lvl_types",
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index b8ef614e04dfad2..7d7749008020515 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -91,8 +91,7 @@ def main():
         for level in levels:
             for ordering in orderings:
                 for bwidth in bitwidths:
-                    attr = st.EncodingAttr.get(level, ordering, None,
-                                               bwidth, bwidth)
+                    attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
                     build_compile_and_run_output(attr, compiler)
                     count = count + 1
 
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 10b390273cc91f3..240db6ebd1d1eb3 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -82,7 +82,11 @@ def testEncodingAttr2D():
         print(f"crd_width: {casted.crd_width}")
 
         created = st.EncodingAttr.get(
-            casted.lvl_types, casted.dim_to_lvl, casted.lvl_to_dim, 8, 32,
+            casted.lvl_types,
+            casted.dim_to_lvl,
+            casted.lvl_to_dim,
+            8,
+            32,
         )
         # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
         print(created)

>From e3ac322d83cfa35429ea517fa129b3d2826d224d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Mon, 16 Oct 2023 19:54:33 +0000
Subject: [PATCH 3/3] address comments

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  3 +++
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        |  5 +----
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 22 +++++++++++--------
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 38c7200afb41ffc..47fd18a689d5a8d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -307,6 +307,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
                      "AffineMap":$lvlToDim,
                      "unsigned":$posWidth,
                      "unsigned":$crdWidth), [{
+      if (!lvlToDim) {
+        lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
+      }
       return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
         ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
     }]>
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 309d5ff5fedb90e..c3ad95527df489f 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -54,11 +54,8 @@ mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
     cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
-  auto unwrappedLvlToDim = unwrap(lvlToDim);
-  if (!unwrappedLvlToDim)
-    unwrappedLvlToDim = inferLvlToDim(unwrap(dimToLvl), unwrap(ctx));
   return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
-                                            unwrap(dimToLvl), unwrappedLvlToDim,
+                                            unwrap(dimToLvl), unwrap(lvlToDim),
                                             posWidth, crdWidth));
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 212de502640ef7b..bcacd3be2101520 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -582,6 +582,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #undef RETURN_ON_FAIL
 
   // Construct struct-like storage for attribute.
+  // TODO: Fetch lvlToDim if user provides one
   AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
   return parser.getChecked<SparseTensorEncodingAttr>(
       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
@@ -770,29 +771,28 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
   lvlExprs.reserve(numLvls);
   // lvlExprComponents stores information of the floordiv and mod operations
   // applied to the same dimension, so as to build the lvlToDim map.
-  // Map key is the position of the dimension in dimToLvl.
-  // Map value is a SmallVector that contains lvl var for floordiv, multiplier,
-  // lvl var for mod in dimToLvl.
-  // For example, for il = i floordiv 2 and ii = i mod 2, the SmalleVector
-  // would be [il, 2, ii]. It could be used to build the AffineExpr
-  // i = il * 2 + ii in lvlToDim.
   std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
   for (unsigned i = 0, n = numLvls; i < n; i++) {
     auto result = dimToLvl.getResult(i);
     if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
       if (result.getKind() == AffineExprKind::FloorDiv) {
+        // Position of the dimension in dimToLvl.
+        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
+               "expected only one floordiv for each dimension");
         SmallVector<AffineExpr, 3> components;
         // Level variable for floordiv.
         components.push_back(getAffineDimExpr(i, context));
         // Multiplier.
         components.push_back(binOp.getRHS());
-        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+        // Map key is the position of the dimension.
         lvlExprComponents[pos] = components;
       } else if (result.getKind() == AffineExprKind::Mod) {
         auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
         assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
                "expected floordiv before mod");
-        // Level variable for mod.
+        // Level variable for mod added to the vector of the corresponding
+        // floordiv with the same dimension.
         lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
       } else {
         assert(false && "expected floordiv or mod");
@@ -801,6 +801,10 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
       lvlExprs.push_back(getAffineDimExpr(i, context));
     }
   }
+  // Build lvlExprs from lvlExprComponents.
+  // For example, for il = i floordiv 2 and ii = i mod 2, the components
+  // would be [il, 2, ii]. It could be used to build the AffineExpr
+  // i = il * 2 + ii in lvlToDim.
   for (auto &components : lvlExprComponents) {
     assert(components.second.size() == 3 &&
            "expected 3 components to build lvlExprs");
@@ -875,7 +879,8 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
   // default value.
   unsigned posWidth = src.getPosWidth();
   unsigned crdWidth = src.getCrdWidth();
-  auto invPerm = src.getLvlToDim();
+  // Invert the new lvlPerm; src's lvlToDim belongs to src's own dimToLvl.
+  AffineMap invPerm = inferLvlToDim(lvlPerm, src.getContext());
   auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
                                            invPerm, posWidth, crdWidth);
   return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);



More information about the Mlir-commits mailing list