[Mlir-commits] [mlir] [mlir][sparse] add lvlToDim field to sparse tensor encoding (PR #67194)

Aart Bik llvmlistbot at llvm.org
Fri Sep 22 15:48:46 PDT 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/67194

>From 26cdb38291a4d9fc38212d26208fbffaaa8d09d9 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 22 Sep 2023 13:39:17 -0700
Subject: [PATCH 1/3] [mlir][sparse] add lvlToDim field to sparse tensor
 encoding

Note the new surface syntax allows for defining a dimToLvl and
lvlToDim map at once (where usually the latter can be inferred
from the former, but not always). This revision adds storage
for the latter, together with some initial boilerplate. The
actual support (inference, validation, printing, etc.) is
still TBD of course.
---
 mlir/include/mlir-c/Dialect/SparseTensor.h    |  6 ++++
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  9 +++--
 .../Bindings/Python/DialectSparseTensor.cpp   |  1 +
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        | 10 ++++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 34 ++++++++++++-------
 .../Transforms/SparseTensorConversion.cpp     |  2 +-
 mlir/test/CAPI/sparse_tensor.c                |  3 +-
 7 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index b2e4b96c65019c5..fecbeaf6b0f9d6c 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -51,6 +51,7 @@ MLIR_CAPI_EXPORTED bool
 mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 
 /// Creates a `sparse_tensor.encoding` attribute with the given parameters.
+/// TODO: add a version that supplies lvlToDim when it cannot be inferred
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
     MlirContext ctx, intptr_t lvlRank,
     enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
@@ -69,6 +70,11 @@ mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
 MLIR_CAPI_EXPORTED MlirAffineMap
 mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr);
 
+/// Returns the level-to-dimension mapping of the `sparse_tensor.encoding`
+/// attribute.
+MLIR_CAPI_EXPORTED MlirAffineMap
+mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr);
+
 /// Returns the position bitwidth of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED int
 mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index cc12f4bfc91a6d5..6a8e66335fbc697 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -237,7 +237,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
   //
   let parameters = (
     ins
-    // A level-type for each level of the sparse storage.
+    // A level-type for each level of the sparse storage
+    // (consists of a level-format combined with level-properties).
     ArrayRefParameter<
       "::mlir::sparse_tensor::DimLevelType",
       "level-types"
@@ -246,6 +247,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     // A mapping from dimension-coordinates to level-coordinates.
     "AffineMap":$dimToLvl,
 
+    // A mapping from level-coordinates to dimension-coordinates.
+    "AffineMap":$lvlToDim,
+
     // The required bitwidth for position storage.
     "unsigned":$posWidth,
 
@@ -262,9 +266,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
   let builders = [
     AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$lvlTypes,
                      "AffineMap":$dimToLvl,
+                     "AffineMap":$lvlToDim,
                      "unsigned":$posWidth,
                      "unsigned":$crdWidth), [{
-      return $_get($_ctxt, lvlTypes, dimToLvl, posWidth, crdWidth,
+      return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
         ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
     }]>
   ];
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 70805005f8afa87..3061e042c851d97 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -43,6 +43,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           [](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
              std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
              MlirContext context) {
+            // TODO: provide lvlToDim
             return cls(mlirSparseTensorEncodingAttrGet(
                 context, lvlTypes.size(), lvlTypes.data(),
                 dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e18da1027e0f33a..bf3a4ad5e7a1683 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -54,14 +54,20 @@ mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
     cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
-  return wrap(SparseTensorEncodingAttr::get(
-      unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), posWidth, crdWidth));
+  mlir::AffineMap lvlToDim; // TODO: provide in API
+  return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
+                                            unwrap(dimToLvl), lvlToDim,
+                                            posWidth, crdWidth));
 }
 
 MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) {
   return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl());
 }
 
+MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) {
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim());
+}
+
 intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 9675a61109477b5..1c75df41e33daa4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -293,8 +293,10 @@ Type SparseTensorEncodingAttr::getCrdType() const {
 SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
+  // TODO: infer lvlToDim
   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
-                                       getPosWidth(), getCrdWidth());
+                                       /*lvlToDim*/ AffineMap(), getPosWidth(),
+                                       getCrdWidth());
 }
 
 SparseTensorEncodingAttr
@@ -311,7 +313,8 @@ SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                         unsigned crdWidth) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
-                                       getDimToLvl(), posWidth, crdWidth);
+                                       getDimToLvl(), getLvlToDim(), posWidth,
+                                       crdWidth);
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
@@ -321,8 +324,8 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
-                                       getDimToLvl(), getPosWidth(),
-                                       getCrdWidth(), dimSlices);
+                                       getDimToLvl(), getLvlToDim(),
+                                       getPosWidth(), getCrdWidth(), dimSlices);
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
@@ -576,8 +579,10 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #undef RETURN_ON_FAIL
 
   // Construct struct-like storage for attribute.
+  AffineMap lvlToDim; // TODO: infer
   return parser.getChecked<SparseTensorEncodingAttr>(
-      parser.getContext(), lvlTypes, dimToLvl, posWidth, crdWidth, dimSlices);
+      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
+      dimSlices);
 }
 
 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
@@ -608,10 +613,12 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
   printer << " }>";
 }
 
-LogicalResult SparseTensorEncodingAttr::verify(
-    function_ref<InFlightDiagnostic()> emitError,
-    ArrayRef<DimLevelType> lvlTypes, AffineMap dimToLvl, unsigned posWidth,
-    unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
+LogicalResult
+SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<DimLevelType> lvlTypes,
+                                 AffineMap dimToLvl, AffineMap lvlToDim,
+                                 unsigned posWidth, unsigned crdWidth,
+                                 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
   if (!acceptBitWidth(posWidth))
     return emitError() << "unexpected position bitwidth: " << posWidth;
   if (!acceptBitWidth(crdWidth))
@@ -631,7 +638,7 @@ LogicalResult SparseTensorEncodingAttr::verify(
       return emitError()
              << "level-rank mismatch between dimToLvl and lvlTypes: "
              << dimToLvl.getNumResults() << " != " << lvlRank;
-    // TODO: The following is attempting to match the old error-conditions
+    // TODO:  The following is attempting to match the old error-conditions
     // from prior to merging dimOrdering and higherOrdering into dimToLvl.
     // That is, we currently require `dimToLvl` to be either a permutation
     // (as when higherOrdering is the identity) or expansive (as per the
@@ -674,7 +681,8 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
   // Check structural integrity.  In particular, this ensures that the
   // level-rank is coherent across all the fields.
   RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimToLvl(),
-                                  getPosWidth(), getCrdWidth(), getDimSlices()))
+                                  getLvlToDim(), getPosWidth(), getCrdWidth(),
+                                  getDimSlices()))
   // Check integrity with tensor type specifics.  In particular, we
   // need only check that the dimension-rank of the tensor agrees with
   // the dimension-rank of the encoding.
@@ -763,8 +771,9 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
   // default value.
   unsigned posWidth = src.getPosWidth();
   unsigned crdWidth = src.getCrdWidth();
+  AffineMap invPerm; // TODO
   auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
-                                           posWidth, crdWidth);
+                                           invPerm, posWidth, crdWidth);
   return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
 }
 
@@ -836,6 +845,7 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
   return SparseTensorEncodingAttr::get(
       enc.getContext(), dlts,
       AffineMap(), // dimToLvl (irrelevant to storage specifier)
+      AffineMap(), // lvlToDim (irrelevant to storage specifier)
       // Always use `index` for memSize and lvlSize instead of reusing
       // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
       // value for different bitwidth, it also avoids casting between index and
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index d75601e369a0d25..9afd805cf87f10d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -986,7 +986,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       const auto dstEnc = SparseTensorEncodingAttr::get(
           op->getContext(),
           SmallVector<DimLevelType>(dimRank, DimLevelType::Dense), AffineMap(),
-          srcEnc.getPosWidth(), srcEnc.getCrdWidth());
+          AffineMap(), srcEnc.getPosWidth(), srcEnc.getCrdWidth());
       SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
       Value iter = NewCallParams(rewriter, loc)
                        .genBuffers(dstTp.withEncoding(dstEnc), dimSizes)
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 6449a8f0c79403c..30ef1557e73302f 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -54,13 +54,12 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: crdWidth: 64
   int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
   fprintf(stderr, "crdWidth: %d\n", crdWidth);
-
+  // TODO: lvlToDim
   MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
       ctx, lvlRank, lvlTypes, dimToLvl, posWidth, crdWidth);
   mlirAttributeDump(newAttr); // For debugging filecheck output.
   // CHECK: equal: 1
   fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
-
   free(lvlTypes);
   return 0;
 }

>From 1593181db7ec5aca8a96931ec1c46381d95a7537 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 22 Sep 2023 15:46:59 -0700
Subject: [PATCH 2/3] add tensor type wrapper too

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorType.h   | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index d9d6db46542a37a..3796d2ddf8d6b40 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -45,12 +45,13 @@ namespace sparse_tensor {
 ///
 class SparseTensorType {
 public:
-  // We memoize `lvlRank` and `dimToLvl` to avoid repeating the
-  // conditionals throughout the rest of the class.
+  // We memoize `lvlRank`, `dimToLvl`, and `lvlToDim1 to avoid repeating
+  // the conditionals throughout the rest of the class.
   SparseTensorType(RankedTensorType rtp)
       : rtp(rtp), enc(getSparseTensorEncoding(rtp)),
         lvlRank(enc ? enc.getLvlRank() : getDimRank()),
-        dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()) {
+        dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()),
+        lvlToDim(enc.isIdentity() ? AffineMap() : enc.getLvlToDim()) {
     assert(rtp && "got null RankedTensorType");
     assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
   }
@@ -201,6 +202,9 @@ class SparseTensorType {
   /// see `hasSameDimToLvl` instead.
   AffineMap getDimToLvl() const { return dimToLvl; }
 
+  /// Returns the lvlToDim mapping (or the null-map for the identity).
+  AffineMap getLvlToDim() const { return lvlToDim; }
+
   /// Returns the dimToLvl mapping, where the identity map is expanded out
   /// into a full `AffineMap`.  This method is provided as a convenience,
   /// but for most purposes other methods (`isIdentity`, `getDimToLvl`,
@@ -306,6 +310,7 @@ class SparseTensorType {
   // Memoized to avoid frequent redundant conditionals.
   const Level lvlRank;
   const AffineMap dimToLvl;
+  const AffineMap lvlToDim;
 };
 
 /// Convenience method to abbreviate wrapping `getRankedTensorType`.

>From bca43f8ff5fad1edaee9d3ac0d33286c8ee45631 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 22 Sep 2023 15:48:28 -0700
Subject: [PATCH 3/3] typo

---
 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 3796d2ddf8d6b40..b5dbc67781ce004 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -45,7 +45,7 @@ namespace sparse_tensor {
 ///
 class SparseTensorType {
 public:
-  // We memoize `lvlRank`, `dimToLvl`, and `lvlToDim1 to avoid repeating
+  // We memoize `lvlRank`, `dimToLvl`, and `lvlToDim` to avoid repeating
   // the conditionals throughout the rest of the class.
   SparseTensorType(RankedTensorType rtp)
       : rtp(rtp), enc(getSparseTensorEncoding(rtp)),



More information about the Mlir-commits mailing list