[Mlir-commits] [mlir] a10d67f - [mlir][sparse] Enable explicit and implicit value in sparse encoding (#88975)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 24 16:20:29 PDT 2024


Author: Yinying Li
Date: 2024-04-24T16:20:25-07:00
New Revision: a10d67f9fb559d0c35a12b2d26974636bbf642c0

URL: https://github.com/llvm/llvm-project/commit/a10d67f9fb559d0c35a12b2d26974636bbf642c0
DIFF: https://github.com/llvm/llvm-project/commit/a10d67f9fb559d0c35a12b2d26974636bbf642c0.diff

LOG: [mlir][sparse] Enable explicit and implicit value in sparse encoding (#88975)

1. Explicit value means the non-zero value in a sparse tensor. If
explicitVal is set, then all the non-zero values in the tensor have the
same explicit value. The default value Attribute() indicates that it is
not set.

2. Implicit value means the "zero" value in a sparse tensor. If
implicitVal is set, then the "zero" value in the tensor is equal to the
implicit value. For now, we only support `0` as the implicit value but
it could be extended in the future. The default value Attribute()
indicates that the implicit value is `0` (same type as the tensor
element type).

Example:

```
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  posWidth = 64,
  crdWidth = 64,
  explicitVal = 1 : i64,
  implicitVal = 0 : i64
}>
```

Note: this PR tests that implicitVal can also be set to other values. A
follow-up PR will add a verifier that rejects any non-zero value for
implicitVal.

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
    mlir/lib/Bindings/Python/DialectSparseTensor.cpp
    mlir/lib/CAPI/Dialect/SparseTensor.cpp
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/CAPI/sparse_tensor.c
    mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
    mlir/test/python/dialects/sparse_tensor/dialect.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 52ca7ba8a1618f..125469f57c5f55 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -53,7 +53,8 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
     MlirContext ctx, intptr_t lvlRank,
     MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
-    MlirAffineMap lvlTodim, int posWidth, int crdWidth);
+    MlirAffineMap lvlTodim, int posWidth, int crdWidth,
+    MlirAttribute explicitVal, MlirAttribute implicitVal);
 
 /// Returns the level-rank of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED intptr_t
@@ -85,6 +86,14 @@ mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);
 MLIR_CAPI_EXPORTED int
 mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr);
 
+/// Returns the explicit value of the `sparse_tensor.encoding` attribute.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr);
+
+/// Returns the implicit value of the `sparse_tensor.encoding` attribute.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr);
+
 MLIR_CAPI_EXPORTED unsigned
 mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType);
 

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 4a9b9169ae4b86..eefa4c71bbd2ca 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -167,7 +167,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **soa** : only applicable to singleton levels, fuses the singleton
       level in SoA (structure of arrays) scheme.
 
-    In addition to the map, the following two fields are optional:
+    In addition to the map, the following fields are optional:
 
     - The required bitwidth for position storage (integral offsets
       into the sparse storage scheme).  A narrow width reduces the memory
@@ -183,6 +183,23 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
       coordinate over all levels).  The choices are `8`, `16`, `32`,
       `64`, or, the default, `0` to indicate a native bitwidth.
 
+    - The explicit value for the sparse tensor. If explicitVal is set,
+      then all the non-zero values in the tensor have the same explicit value.
+      The default value Attribute() indicates that it is not set. This
+      is useful for binary-valued sparse tensors whose values can either
+      be an implicit value (0 by default) or an explicit value (such as 1).
+      In this approach, we don't store explicit/implicit values, and metadata
+      (such as position and coordinate arrays) alone fully defines the original tensor.
+      This yields additional savings for the storage requirements,
+      as well as for the computational time, since we skip operating on
+      implicit values and can constant fold the explicit values where they are used.
+
+    - The implicit value for the sparse tensor. If implicitVal is set,
+      then the "zero" value in the tensor is equal to the implicit value.
+      For now, we only support `0` as the implicit value but it could be
+      extended in the future. The default value Attribute() indicates that
+      the implicit value is `0` (same type as the tensor element type).
+
     Examples:
 
     ```mlir
@@ -226,6 +243,15 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     }>
     ... tensor<8x8xf64, #DCSC> ...
 
+    // Doubly compressed sparse column storage with specific
+    // explicit and implicit values.
+    #DCSC = #sparse_tensor.encoding<{
+      map = (i, j) -> (j : compressed, i : compressed),
+      explicitVal = 1 : i64,
+      implicitVal = 0 : i64
+    }>
+    ... tensor<8x8xi64, #DCSC> ...
+
     // Block sparse row storage (2x3 blocks).
     #BSR = #sparse_tensor.encoding<{
       map = ( i, j ) ->
@@ -307,6 +333,12 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     // The required bitwidth for coordinate storage.
     "unsigned":$crdWidth,
 
+    // The explicit value shared by all stored entries (null Attribute if unset).
+    "::mlir::Attribute":$explicitVal,
+
+    // The implicit ("zero") value (null Attribute means `0`).
+    "::mlir::Attribute":$implicitVal,
+
     // A slice attribute for each dimension of the tensor type.
     ArrayRefParameter<
       "::mlir::sparse_tensor::SparseTensorDimSliceAttr",
@@ -319,7 +351,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
                      CArg<"AffineMap", "{}">:$dimToLvl,
                      CArg<"AffineMap", "{}">:$lvlToDim,
                      CArg<"unsigned", "0">:$posWidth,
-                     CArg<"unsigned", "0">:$crdWidth), [{
+                     CArg<"unsigned", "0">:$crdWidth,
+                     CArg<"::mlir::Attribute", "{}">:$explicitVal,
+                     CArg<"::mlir::Attribute", "{}">:$implicitVal), [{
       if (!dimToLvl) {
         dimToLvl = ::mlir::AffineMap::getMultiDimIdentityMap(lvlTypes.size(), $_ctxt);
       }
@@ -327,6 +361,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
         lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
       }
       return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
+        explicitVal, implicitVal,
         ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
     }]>
   ];
@@ -353,6 +388,22 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// reset to the default, and all other fields inherited from `this`.
     SparseTensorEncodingAttr withoutBitWidths() const;
 
+    /// Constructs a new encoding with the given explicit value
+    /// and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withExplicitVal(Attribute explicitVal) const;
+
+    /// Constructs a new encoding with the explicit value
+    /// reset to the default, and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withoutExplicitVal() const;
+
+    /// Constructs a new encoding with the given implicit value
+    /// and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withImplicitVal(Attribute implicitVal) const;
+
+    /// Constructs a new encoding with the implicit value
+    /// reset to the default, and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withoutImplicitVal() const;
+
     /// Constructs a new encoding with the given dimSlices, and all
     /// other fields inherited from `this`.
     SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 825d89a408febe..34d99913fbd51b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -115,6 +115,22 @@ class SparseTensorType {
     return withEncoding(enc.withoutBitWidths());
   }
 
+  SparseTensorType withExplicitVal(Attribute explicitVal) const {
+    return withEncoding(enc.withExplicitVal(explicitVal));
+  }
+
+  SparseTensorType withoutExplicitVal() const {
+    return withEncoding(enc.withoutExplicitVal());
+  }
+
+  SparseTensorType withImplicitVal(Attribute implicitVal) const {
+    return withEncoding(enc.withImplicitVal(implicitVal));
+  }
+
+  SparseTensorType withoutImplicitVal() const {
+    return withEncoding(enc.withoutImplicitVal());
+  }
+
   SparseTensorType
   withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
     return withEncoding(enc.withDimSlices(dimSlices));
@@ -327,6 +343,12 @@ class SparseTensorType {
   /// Returns the position-overhead bitwidth, defaulting to zero.
   unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }
 
+  /// Returns the explicit value, defaulting to null Attribute for unset.
+  Attribute getExplicitVal() const { return enc.getExplicitVal(); }
+
+  /// Returns the implicit value, defaulting to null Attribute for 0.
+  Attribute getImplicitVal() const { return enc.getImplicitVal(); }
+
   /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
   Type getCrdType() const { return enc.getCrdElemType(); }
 

diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 171faf9e008746..584981cfe99bf1 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -42,16 +42,19 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
              std::optional<MlirAffineMap> dimToLvl,
              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
-             MlirContext context) {
+             std::optional<MlirAttribute> explicitVal,
+             std::optional<MlirAttribute> implicitVal, MlirContext context) {
             return cls(mlirSparseTensorEncodingAttrGet(
                 context, lvlTypes.size(), lvlTypes.data(),
                 dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
                 lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
-                crdWidth));
+                crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
+                implicitVal ? *implicitVal : MlirAttribute{nullptr}));
           },
           py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
           py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
-          py::arg("context") = py::none(),
+          py::arg("explicit_val") = py::none(),
+          py::arg("implicit_val") = py::none(), py::arg("context") = py::none(),
           "Gets a sparse_tensor.encoding from parameters.")
       .def_classmethod(
           "build_level_type",
@@ -97,6 +100,24 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                              mlirSparseTensorEncodingAttrGetPosWidth)
       .def_property_readonly("crd_width",
                              mlirSparseTensorEncodingAttrGetCrdWidth)
+      .def_property_readonly(
+          "explicit_val",
+          [](MlirAttribute self) -> std::optional<MlirAttribute> {
+            MlirAttribute ret =
+                mlirSparseTensorEncodingAttrGetExplicitVal(self);
+            if (mlirAttributeIsNull(ret))
+              return {};
+            return ret;
+          })
+      .def_property_readonly(
+          "implicit_val",
+          [](MlirAttribute self) -> std::optional<MlirAttribute> {
+            MlirAttribute ret =
+                mlirSparseTensorEncodingAttrGetImplicitVal(self);
+            if (mlirAttributeIsNull(ret))
+              return {};
+            return ret;
+          })
       .def_property_readonly(
           "structured_n",
           [](MlirAttribute self) -> unsigned {

diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 3ae06f220c5281..19171d64d40949 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -44,18 +44,20 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));
 }
 
-MlirAttribute
-mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
-                                MlirSparseTensorLevelType const *lvlTypes,
-                                MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
-                                int posWidth, int crdWidth) {
+MlirAttribute mlirSparseTensorEncodingAttrGet(
+    MlirContext ctx, intptr_t lvlRank,
+    MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
+    MlirAffineMap lvlToDim, int posWidth, int crdWidth,
+    MlirAttribute explicitVal, MlirAttribute implicitVal) {
   SmallVector<LevelType> cppLvlTypes;
+
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
     cppLvlTypes.push_back(static_cast<LevelType>(lvlTypes[l]));
-  return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
-                                            unwrap(dimToLvl), unwrap(lvlToDim),
-                                            posWidth, crdWidth));
+
+  return wrap(SparseTensorEncodingAttr::get(
+      unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth,
+      crdWidth, unwrap(explicitVal), unwrap(implicitVal)));
 }
 
 MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) {
@@ -91,6 +93,14 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
 }
 
+MlirAttribute mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr) {
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getExplicitVal());
+}
+
+MlirAttribute mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr) {
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getImplicitVal());
+}
+
 MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
     enum MlirSparseTensorLevelFormat lvlFmt,
     const enum MlirSparseTensorLevelPropertyNondefault *properties,

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b1d44559fa5aba..028a69da10c1e1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -326,9 +326,9 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
-                                       AffineMap(), getPosWidth(),
-                                       getCrdWidth());
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
+      getCrdWidth(), getExplicitVal(), getImplicitVal());
 }
 
 SparseTensorEncodingAttr
@@ -344,20 +344,44 @@ SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                         unsigned crdWidth) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
-                                       getDimToLvl(), getLvlToDim(), posWidth,
-                                       crdWidth);
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
+      crdWidth, getExplicitVal(), getImplicitVal());
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
   return withBitWidths(0, 0);
 }
 
+SparseTensorEncodingAttr
+SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
+  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
+      getCrdWidth(), explicitVal, getImplicitVal());
+}
+
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
+  return withExplicitVal(Attribute());
+}
+
+SparseTensorEncodingAttr
+SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
+  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
+      getCrdWidth(), getExplicitVal(), implicitVal);
+}
+
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
+  return withImplicitVal(Attribute());
+}
+
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
-  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
-                                       getDimToLvl(), getLvlToDim(),
-                                       getPosWidth(), getCrdWidth(), dimSlices);
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
+      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
@@ -553,8 +577,11 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   AffineMap lvlToDim = {};
   unsigned posWidth = 0;
   unsigned crdWidth = 0;
+  Attribute explicitVal;
+  Attribute implicitVal;
   StringRef attrName;
-  SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
+  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
+                                    "explicitVal", "implicitVal"};
   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
     // Detect admissible keyword.
     auto *it = find(keys, attrName);
@@ -628,6 +655,36 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
       crdWidth = intAttr.getInt();
       break;
     }
+    case 3: { // explicitVal
+      Attribute attr;
+      if (failed(parser.parseAttribute(attr)))
+        return {};
+      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
+        explicitVal = result;
+      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
+        explicitVal = result;
+      } else {
+        parser.emitError(parser.getNameLoc(),
+                         "expected a numeric value for explicitVal");
+        return {};
+      }
+      break;
+    }
+    case 4: { // implicitVal
+      Attribute attr;
+      if (failed(parser.parseAttribute(attr)))
+        return {};
+      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
+        implicitVal = result;
+      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
+        implicitVal = result;
+      } else {
+        parser.emitError(parser.getNameLoc(),
+                         "expected a numeric value for implicitVal");
+        return {};
+      }
+      break;
+    }
     } // switch
     // Only last item can omit the comma.
     if (parser.parseOptionalComma().failed())
@@ -646,7 +703,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   }
   return parser.getChecked<SparseTensorEncodingAttr>(
       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
-      dimSlices);
+      explicitVal, implicitVal, dimSlices);
 }
 
 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
@@ -666,6 +723,11 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
     printer << ", posWidth = " << getPosWidth();
   if (getCrdWidth())
     printer << ", crdWidth = " << getCrdWidth();
+  if (getExplicitVal()) {
+    printer << ", explicitVal = " << getExplicitVal();
+  }
+  if (getImplicitVal())
+    printer << ", implicitVal = " << getImplicitVal();
   printer << " }>";
 }
 
@@ -715,7 +777,8 @@ void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
 LogicalResult SparseTensorEncodingAttr::verify(
     function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
     AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
-    unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
+    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
+    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
   if (!acceptBitWidth(posWidth))
     return emitError() << "unexpected position bitwidth: " << posWidth;
   if (!acceptBitWidth(crdWidth))
@@ -831,7 +894,8 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
   // Check structural integrity.  In particular, this ensures that the
   // level-rank is coherent across all the fields.
   if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
-                    getPosWidth(), getCrdWidth(), getDimSlices())))
+                    getPosWidth(), getCrdWidth(), getExplicitVal(),
+                    getImplicitVal(), getDimSlices())))
     return failure();
   // Check integrity with tensor type specifics.  In particular, we
   // need only check that the dimension-rank of the tensor agrees with
@@ -921,9 +985,9 @@ mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
     // Ends by a unique singleton level.
     lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
   }
-  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
-                                           getDimToLvl(), getLvlToDim(),
-                                           getPosWidth(), getCrdWidth());
+  auto enc = SparseTensorEncodingAttr::get(
+      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
+      getCrdWidth(), getExplicitVal(), getImplicitVal());
   return RankedTensorType::get(getDimShape(), getElementType(), enc);
 }
 
@@ -1115,7 +1179,10 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
       // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
       // value for different bitwidth, it also avoids casting between index and
       // integer (returned by DimOp)
-      0, 0, enc.getDimSlices());
+      0, 0,
+      Attribute(), // explicitVal (irrelevant to storage specifier)
+      Attribute(), // implicitVal (irrelevant to storage specifier)
+      enc.getDimSlices());
 }
 
 StorageSpecifierType

diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index f241e0e5c2fb56..22b7052b732aa2 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -27,7 +27,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
   const char *originalAsm =
     "#sparse_tensor.encoding<{ "
     "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), "
-    "posWidth = 32, crdWidth = 64 }>";
+    "posWidth = 32, crdWidth = 64, explicitVal = 1 : i64}>";
   // clang-format on
   MlirAttribute originalAttr =
       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(originalAsm));
@@ -56,8 +56,21 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: crdWidth: 64
   int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
   fprintf(stderr, "crdWidth: %d\n", crdWidth);
+
+  // CHECK: explicitVal: 1 : i64
+  MlirAttribute explicitVal =
+      mlirSparseTensorEncodingAttrGetExplicitVal(originalAttr);
+  fprintf(stderr, "explicitVal: ");
+  mlirAttributeDump(explicitVal);
+  // CHECK: implicitVal: <<NULL ATTRIBUTE>>
+  MlirAttribute implicitVal =
+      mlirSparseTensorEncodingAttrGetImplicitVal(originalAttr);
+  fprintf(stderr, "implicitVal: ");
+  mlirAttributeDump(implicitVal);
+
   MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
-      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth);
+      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
+      explicitVal, implicitVal);
   mlirAttributeDump(newAttr); // For debugging filecheck output.
   // CHECK: equal: 1
   fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 66e61afd897dd1..7eeda9a9880268 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -22,6 +22,64 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
 
 // -----
 
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 64,
+  crdWidth = 64,
+  explicitVal = 1.0 : f32,
+  implicitVal = 0.0 : f32
+}>
+
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1.000000e+00 : f32, implicitVal = 0.000000e+00 : f32 }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xf32, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_OnlyOnes>)
+
+// -----
+
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = 1.0 : f64,
+  implicitVal = 0.0 : f64
+}>
+
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), explicitVal = 1.000000e+00 : f64, implicitVal = 0.000000e+00 : f64 }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xf64, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xf64, #CSR_OnlyOnes>)
+
+// -----
+
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 64,
+  crdWidth = 64,
+  explicitVal = 1 : i32,
+  implicitVal = 0 : i32
+}>
+
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i32, implicitVal = 0 : i32 }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xi32, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_OnlyOnes>)
+
+// -----
+
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 64,
+  crdWidth = 64,
+  explicitVal = 1 : i64,
+  implicitVal = 0 : i64
+}>
+
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 0 : i64 }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xi64, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xi64, #CSR_OnlyOnes>)
+
+// -----
+
 #BCSR = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
 }>

diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 5666d090c3d5ee..3cc4575eb3e240 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -2,6 +2,7 @@
 
 from mlir.ir import *
 from mlir.dialects import sparse_tensor as st
+import textwrap
 
 
 def run(f):
@@ -15,13 +16,18 @@ def run(f):
 def testEncodingAttr1D():
     with Context() as ctx:
         parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0) -> (d0 : compressed),"
-            "  posWidth = 16,"
-            "  crdWidth = 32"
-            "}>"
+            textwrap.dedent(
+                """\
+                #sparse_tensor.encoding<{
+                    map = (d0) -> (d0 : compressed),
+                    posWidth = 16,
+                    crdWidth = 32,
+                    explicitVal = 1.0 : f64
+                }>\
+            """
+            )
         )
-        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32, explicitVal = 1.000000e+00 : f64 }>
         print(parsed)
 
         casted = st.EncodingAttr(parsed)
@@ -38,9 +44,16 @@ def testEncodingAttr1D():
         print(f"pos_width: {casted.pos_width}")
         # CHECK: crd_width: 32
         print(f"crd_width: {casted.crd_width}")
+        # CHECK: explicit_val: 1.000000e+00
+        print(f"explicit_val: {casted.explicit_val}")
+        # CHECK: implicit_val: None
+        print(f"implicit_val: {casted.implicit_val}")
 
-        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
-        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+        new_explicit_val = FloatAttr.get_f64(1.0)
+        created = st.EncodingAttr.get(
+            casted.lvl_types, None, None, 0, 0, new_explicit_val
+        )
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), explicitVal = 1.000000e+00 : f64 }>
         print(created)
         # CHECK: created_equal: False
         print(f"created_equal: {created == casted}")
@@ -57,12 +70,16 @@ def testEncodingAttr1D():
 def testEncodingAttrStructure():
     with Context() as ctx:
         parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,"
-            "  d1 mod 4 : structured[2, 4]),"
-            "  posWidth = 16,"
-            "  crdWidth = 32"
-            "}>"
+            textwrap.dedent(
+                """\
+                #sparse_tensor.encoding<{
+                    map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,
+                    d1 mod 4 : structured[2, 4]),
+                    posWidth = 16,
+                    crdWidth = 32,
+                }>\
+            """
+            )
         )
         # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
         print(parsed)
@@ -144,11 +161,15 @@ def testEncodingAttrStructure():
 def testEncodingAttr2D():
     with Context() as ctx:
         parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
-            "  posWidth = 8,"
-            "  crdWidth = 32"
-            "}>"
+            textwrap.dedent(
+                """\
+                #sparse_tensor.encoding<{
+                    map = (d0, d1) -> (d1 : dense, d0 : compressed),
+                    posWidth = 8,
+                    crdWidth = 32,
+                }>\
+            """
+            )
         )
         # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
         print(parsed)
@@ -187,11 +208,15 @@ def testEncodingAttrOnTensorType():
     with Context() as ctx, Location.unknown():
         encoding = st.EncodingAttr(
             Attribute.parse(
-                "#sparse_tensor.encoding<{"
-                "  map = (d0) -> (d0 : compressed), "
-                "  posWidth = 64,"
-                "  crdWidth = 32"
-                "}>"
+                textwrap.dedent(
+                    """\
+                    #sparse_tensor.encoding<{
+                        map = (d0) -> (d0 : compressed),
+                        posWidth = 64,
+                        crdWidth = 32,
+                    }>\
+                """
+                )
             )
         )
         tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)


        


More information about the Mlir-commits mailing list