[Mlir-commits] [mlir] [mlir][sparse] add lvlToDim field to sparse tensor encoding (PR #67194)

Aart Bik llvmlistbot at llvm.org
Fri Sep 22 13:51:10 PDT 2023


https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/67194

Note the new surface syntax allows for defining a dimToLvl and lvlToDim map at once (where usually the latter can be inferred from the former, but not always). This revision adds storage for the latter, together with some initial boilerplate. The actual support (inference, validation, printing, etc.) is still TBD of course.

>From 26cdb38291a4d9fc38212d26208fbffaaa8d09d9 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Fri, 22 Sep 2023 13:39:17 -0700
Subject: [PATCH] [mlir][sparse] add lvlToDim field to sparse tensor encoding

Note the new surface syntax allows for defining a dimToLvl and
lvlToDim map at once (where usually the latter can be inferred
from the former, but not always). This revision adds storage
for the latter, together with some initial boilerplate. The
actual support (inference, validation, printing, etc.) is
still TBD of course.
---
 mlir/include/mlir-c/Dialect/SparseTensor.h    |  6 ++++
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  9 +++--
 .../Bindings/Python/DialectSparseTensor.cpp   |  1 +
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        | 10 ++++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 34 ++++++++++++-------
 .../Transforms/SparseTensorConversion.cpp     |  2 +-
 mlir/test/CAPI/sparse_tensor.c                |  3 +-
 7 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index b2e4b96c65019c5..fecbeaf6b0f9d6c 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -51,6 +51,7 @@ MLIR_CAPI_EXPORTED bool
 mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 
 /// Creates a `sparse_tensor.encoding` attribute with the given parameters.
+/// TODO: add a version that supplies lvlToDim when it cannot be inferred.
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
     MlirContext ctx, intptr_t lvlRank,
     enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
@@ -69,6 +70,11 @@ mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
 MLIR_CAPI_EXPORTED MlirAffineMap
 mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr);
 
+/// Returns the level-to-dimension mapping of the `sparse_tensor.encoding`
+/// attribute.
+MLIR_CAPI_EXPORTED MlirAffineMap
+mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr);
+
 /// Returns the position bitwidth of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED int
 mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index cc12f4bfc91a6d5..6a8e66335fbc697 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -237,7 +237,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
   //
   let parameters = (
     ins
-    // A level-type for each level of the sparse storage.
+    // A level-type for each level of the sparse storage
+    // (consists of a level-format combined with level-properties).
     ArrayRefParameter<
       "::mlir::sparse_tensor::DimLevelType",
       "level-types"
@@ -246,6 +247,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     // A mapping from dimension-coordinates to level-coordinates.
     "AffineMap":$dimToLvl,
 
+    // A mapping from level-coordinates to dimension-coordinates.
+    "AffineMap":$lvlToDim,
+
     // The required bitwidth for position storage.
     "unsigned":$posWidth,
 
@@ -262,9 +266,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
   let builders = [
     AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$lvlTypes,
                      "AffineMap":$dimToLvl,
+                     "AffineMap":$lvlToDim,
                      "unsigned":$posWidth,
                      "unsigned":$crdWidth), [{
-      return $_get($_ctxt, lvlTypes, dimToLvl, posWidth, crdWidth,
+      return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
         ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
     }]>
   ];
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 70805005f8afa87..3061e042c851d97 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -43,6 +43,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           [](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
              std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
              MlirContext context) {
+            // TODO: provide lvlToDim
             return cls(mlirSparseTensorEncodingAttrGet(
                 context, lvlTypes.size(), lvlTypes.data(),
                 dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e18da1027e0f33a..bf3a4ad5e7a1683 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -54,14 +54,20 @@ mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
     cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
-  return wrap(SparseTensorEncodingAttr::get(
-      unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), posWidth, crdWidth));
+  mlir::AffineMap lvlToDim; // TODO: provide in API
+  return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
+                                            unwrap(dimToLvl), lvlToDim,
+                                            posWidth, crdWidth));
 }
 
 MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) {
   return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimToLvl());
 }
 
+MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) {
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlToDim());
+}
+
 intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 9675a61109477b5..1c75df41e33daa4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -293,8 +293,10 @@ Type SparseTensorEncodingAttr::getCrdType() const {
 SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
+  // TODO: infer lvlToDim
   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
-                                       getPosWidth(), getCrdWidth());
+                                       /*lvlToDim*/ AffineMap(), getPosWidth(),
+                                       getCrdWidth());
 }
 
 SparseTensorEncodingAttr
@@ -311,7 +313,8 @@ SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                         unsigned crdWidth) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
-                                       getDimToLvl(), posWidth, crdWidth);
+                                       getDimToLvl(), getLvlToDim(), posWidth,
+                                       crdWidth);
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
@@ -321,8 +324,8 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
   return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
-                                       getDimToLvl(), getPosWidth(),
-                                       getCrdWidth(), dimSlices);
+                                       getDimToLvl(), getLvlToDim(),
+                                       getPosWidth(), getCrdWidth(), dimSlices);
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
@@ -576,8 +579,10 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #undef RETURN_ON_FAIL
 
   // Construct struct-like storage for attribute.
+  AffineMap lvlToDim; // TODO: infer
   return parser.getChecked<SparseTensorEncodingAttr>(
-      parser.getContext(), lvlTypes, dimToLvl, posWidth, crdWidth, dimSlices);
+      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
+      dimSlices);
 }
 
 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
@@ -608,10 +613,12 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
   printer << " }>";
 }
 
-LogicalResult SparseTensorEncodingAttr::verify(
-    function_ref<InFlightDiagnostic()> emitError,
-    ArrayRef<DimLevelType> lvlTypes, AffineMap dimToLvl, unsigned posWidth,
-    unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
+LogicalResult
+SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<DimLevelType> lvlTypes,
+                                 AffineMap dimToLvl, AffineMap lvlToDim,
+                                 unsigned posWidth, unsigned crdWidth,
+                                 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
   if (!acceptBitWidth(posWidth))
     return emitError() << "unexpected position bitwidth: " << posWidth;
   if (!acceptBitWidth(crdWidth))
@@ -631,7 +638,7 @@ LogicalResult SparseTensorEncodingAttr::verify(
       return emitError()
              << "level-rank mismatch between dimToLvl and lvlTypes: "
              << dimToLvl.getNumResults() << " != " << lvlRank;
-    // TODO: The following is attempting to match the old error-conditions
+    // TODO:  The following is attempting to match the old error-conditions
     // from prior to merging dimOrdering and higherOrdering into dimToLvl.
     // That is, we currently require `dimToLvl` to be either a permutation
     // (as when higherOrdering is the identity) or expansive (as per the
@@ -674,7 +681,8 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
   // Check structural integrity.  In particular, this ensures that the
   // level-rank is coherent across all the fields.
   RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimToLvl(),
-                                  getPosWidth(), getCrdWidth(), getDimSlices()))
+                                  getLvlToDim(), getPosWidth(), getCrdWidth(),
+                                  getDimSlices()))
   // Check integrity with tensor type specifics.  In particular, we
   // need only check that the dimension-rank of the tensor agrees with
   // the dimension-rank of the encoding.
@@ -763,8 +771,9 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
   // default value.
   unsigned posWidth = src.getPosWidth();
   unsigned crdWidth = src.getCrdWidth();
+  AffineMap invPerm; // TODO
   auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
-                                           posWidth, crdWidth);
+                                           invPerm, posWidth, crdWidth);
   return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
 }
 
@@ -836,6 +845,7 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
   return SparseTensorEncodingAttr::get(
       enc.getContext(), dlts,
       AffineMap(), // dimToLvl (irrelevant to storage specifier)
+      AffineMap(), // lvlToDim (irrelevant to storage specifier)
       // Always use `index` for memSize and lvlSize instead of reusing
       // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
       // value for different bitwidth, it also avoids casting between index and
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index d75601e369a0d25..9afd805cf87f10d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -986,7 +986,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       const auto dstEnc = SparseTensorEncodingAttr::get(
           op->getContext(),
           SmallVector<DimLevelType>(dimRank, DimLevelType::Dense), AffineMap(),
-          srcEnc.getPosWidth(), srcEnc.getCrdWidth());
+          AffineMap(), srcEnc.getPosWidth(), srcEnc.getCrdWidth());
       SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
       Value iter = NewCallParams(rewriter, loc)
                        .genBuffers(dstTp.withEncoding(dstEnc), dimSizes)
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 6449a8f0c79403c..30ef1557e73302f 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -54,13 +54,12 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: crdWidth: 64
   int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
   fprintf(stderr, "crdWidth: %d\n", crdWidth);
-
+  // TODO: also roundtrip lvlToDim once the C API accepts it
   MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
       ctx, lvlRank, lvlTypes, dimToLvl, posWidth, crdWidth);
   mlirAttributeDump(newAttr); // For debugging filecheck output.
   // CHECK: equal: 1
   fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
-
   free(lvlTypes);
   return 0;
 }



More information about the Mlir-commits mailing list