[Mlir-commits] [mlir] [mlir][sparse] Enable explicit and implicit value in sparse encoding (PR #88975)

Yinying Li llvmlistbot at llvm.org
Thu Apr 18 13:28:09 PDT 2024


https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/88975

>From c8fe1e1375dcbf52f2ef55e8ecd9a76415001438 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 16 Apr 2024 19:09:41 +0000
Subject: [PATCH 1/6] [mlir][sparse] Enable explicit and implicit value in
 sparse encoding

1. The explicit value is the value of the non-zero (stored) entries of a sparse tensor. If explicitVal is set, then all the non-zero values in the tensor have that same explicit value. Its default is Attribute(), which indicates that it is not set.
2. The implicit value is the "zero" value of a sparse tensor. For now, only 0 is supported as the implicit value, but this could be extended in the future. Its default is Attribute(), which indicates that the implicit value is 0.

Example:

#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  posWidth = 64,
  crdWidth = 64,
  explicitVal = 1 : i64,
  implicitVal = 0 : i64
}>

Note: this PR tests that implicitVal can also be set to values other than zero. A follow-up PR will add a verifier that rejects any implicitVal that is not zero.
---
 mlir/include/mlir-c/Dialect/SparseTensor.h    |  11 +-
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  46 +++++++-
 .../SparseTensor/IR/SparseTensorType.h        |  22 ++++
 .../Bindings/Python/DialectSparseTensor.cpp   |  27 ++++-
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        |  26 +++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 101 +++++++++++++++---
 mlir/test/CAPI/sparse_tensor.c                |  15 ++-
 .../SparseTensor/roundtrip_encoding.mlir      |  58 ++++++++++
 .../python/dialects/sparse_tensor/dialect.py  |  16 ++-
 9 files changed, 285 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 52ca7ba8a1618f..125469f57c5f55 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -53,7 +53,8 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
     MlirContext ctx, intptr_t lvlRank,
     MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
-    MlirAffineMap lvlTodim, int posWidth, int crdWidth);
+    MlirAffineMap lvlTodim, int posWidth, int crdWidth,
+    MlirAttribute explicitVal, MlirAttribute implicitVal);
 
 /// Returns the level-rank of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED intptr_t
@@ -85,6 +86,14 @@ mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);
 MLIR_CAPI_EXPORTED int
 mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr);
 
+/// Returns the explicit value of the `sparse_tensor.encoding` attribute.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr);
+
+/// Returns the implicit value of the `sparse_tensor.encoding` attribute.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr);
+
 MLIR_CAPI_EXPORTED unsigned
 mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType);
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 4a9b9169ae4b86..d297a238b6fc3c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -183,6 +183,16 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
       coordinate over all levels).  The choices are `8`, `16`, `32`,
       `64`, or, the default, `0` to indicate a native bitwidth.
 
+    - The required explicit value for the sparse tensor. If explicitVal is set,
+      then all the non-zero values in the tensor have the same explicit value.
+      The default value Attribute() indicates that it is not set.
+
+    - The required implicit value for the sparse tensor. If implicitVal is set,
+      then the "zero" value in the tensor is equal to the implicit value.
+      For now, we only support `0` as the implicit value but it could be
+      extended in the future. The default value Attribute() indicates that
+      the implicit value is `0` (same type as the tensor element type).
+
     Examples:
 
     ```mlir
@@ -226,6 +236,15 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     }>
     ... tensor<8x8xf64, #DCSC> ...
 
+    // Doubly compressed sparse column storage with specific
+    // explicit and implicit values.
+    #DCSC = #sparse_tensor.encoding<{
+      map = (i, j) -> (j : compressed, i : compressed),
+      explicitVal = 1 : i64,
+      implicitVal = 0 : i64
+    }>
+    ... tensor<8x8xi64, #DCSC> ...
+
     // Block sparse row storage (2x3 blocks).
     #BSR = #sparse_tensor.encoding<{
       map = ( i, j ) ->
@@ -307,6 +326,12 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     // The required bitwidth for coordinate storage.
     "unsigned":$crdWidth,
 
+    // The required explicit value.
+    "::mlir::Attribute":$explicitVal,
+
+    // The required implicit value.
+    "::mlir::Attribute":$implicitVal,
+
     // A slice attribute for each dimension of the tensor type.
     ArrayRefParameter<
       "::mlir::sparse_tensor::SparseTensorDimSliceAttr",
@@ -319,7 +344,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
                      CArg<"AffineMap", "{}">:$dimToLvl,
                      CArg<"AffineMap", "{}">:$lvlToDim,
                      CArg<"unsigned", "0">:$posWidth,
-                     CArg<"unsigned", "0">:$crdWidth), [{
+                     CArg<"unsigned", "0">:$crdWidth,
+                     CArg<"::mlir::Attribute", "{}">:$explicitVal,
+                     CArg<"::mlir::Attribute", "{}">:$implicitVal), [{
       if (!dimToLvl) {
         dimToLvl = ::mlir::AffineMap::getMultiDimIdentityMap(lvlTypes.size(), $_ctxt);
       }
@@ -327,6 +354,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
         lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
       }
       return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
+        explicitVal, implicitVal,
         ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
     }]>
   ];
@@ -353,6 +381,22 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// reset to the default, and all other fields inherited from `this`.
     SparseTensorEncodingAttr withoutBitWidths() const;
 
+    /// Constructs a new encoding with the given explicit value
+    /// and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withExplicitVal(Attribute explicitVal) const;
+
+    /// Constructs a new encoding with the explicit value
+    /// reset to the default, and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withoutExplicitVal() const;
+
+    /// Constructs a new encoding with the given implicit value
+    /// and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withImplicitVal(Attribute implicitVal) const;
+
+    /// Constructs a new encoding with the implicit value
+    /// reset to the default, and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withoutImplicitVal() const;
+
     /// Constructs a new encoding with the given dimSlices, and all
     /// other fields inherited from `this`.
     SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 825d89a408febe..25f4ecad27519e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -115,6 +115,22 @@ class SparseTensorType {
     return withEncoding(enc.withoutBitWidths());
   }
 
+  SparseTensorType withExplicitVal(Attribute explicitVal) const {
+    return withEncoding(enc.withExplicitVal(explicitVal));
+  }
+
+  SparseTensorType withoutExplicitVal() const {
+    return withEncoding(enc.withoutExplicitVal());
+  }
+
+  SparseTensorType withImplicitVal(Attribute implicitVal) const {
+    return withEncoding(enc.withImplicitVal(implicitVal));
+  }
+
+  SparseTensorType withoutImplicitVal() const {
+    return withEncoding(enc.withoutImplicitVal());
+  }
+
   SparseTensorType
   withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
     return withEncoding(enc.withDimSlices(dimSlices));
@@ -327,6 +343,12 @@ class SparseTensorType {
   /// Returns the position-overhead bitwidth, defaulting to zero.
   unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }
 
+  /// Returns the explicit value, defaulting to empty Attribute.
+  Attribute getExplicitVal() const { return enc.getExplicitVal(); }
+
+  /// Returns the implicit value, defaulting to empty Attribute.
+  Attribute getImplicitVal() const { return enc.getImplicitVal(); }
+
   /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
   Type getCrdType() const { return enc.getCrdElemType(); }
 
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 171faf9e008746..584981cfe99bf1 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -42,16 +42,19 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
              std::optional<MlirAffineMap> dimToLvl,
              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
-             MlirContext context) {
+             std::optional<MlirAttribute> explicitVal,
+             std::optional<MlirAttribute> implicitVal, MlirContext context) {
             return cls(mlirSparseTensorEncodingAttrGet(
                 context, lvlTypes.size(), lvlTypes.data(),
                 dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
                 lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
-                crdWidth));
+                crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr},
+                implicitVal ? *implicitVal : MlirAttribute{nullptr}));
           },
           py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
           py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
-          py::arg("context") = py::none(),
+          py::arg("explicit_val") = py::none(),
+          py::arg("implicit_val") = py::none(), py::arg("context") = py::none(),
           "Gets a sparse_tensor.encoding from parameters.")
       .def_classmethod(
           "build_level_type",
@@ -97,6 +100,24 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                              mlirSparseTensorEncodingAttrGetPosWidth)
       .def_property_readonly("crd_width",
                              mlirSparseTensorEncodingAttrGetCrdWidth)
+      .def_property_readonly(
+          "explicit_val",
+          [](MlirAttribute self) -> std::optional<MlirAttribute> {
+            MlirAttribute ret =
+                mlirSparseTensorEncodingAttrGetExplicitVal(self);
+            if (mlirAttributeIsNull(ret))
+              return {};
+            return ret;
+          })
+      .def_property_readonly(
+          "implicit_val",
+          [](MlirAttribute self) -> std::optional<MlirAttribute> {
+            MlirAttribute ret =
+                mlirSparseTensorEncodingAttrGetImplicitVal(self);
+            if (mlirAttributeIsNull(ret))
+              return {};
+            return ret;
+          })
       .def_property_readonly(
           "structured_n",
           [](MlirAttribute self) -> unsigned {
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 3ae06f220c5281..19171d64d40949 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -44,18 +44,20 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));
 }
 
-MlirAttribute
-mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
-                                MlirSparseTensorLevelType const *lvlTypes,
-                                MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
-                                int posWidth, int crdWidth) {
+MlirAttribute mlirSparseTensorEncodingAttrGet(
+    MlirContext ctx, intptr_t lvlRank,
+    MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
+    MlirAffineMap lvlToDim, int posWidth, int crdWidth,
+    MlirAttribute explicitVal, MlirAttribute implicitVal) {
   SmallVector<LevelType> cppLvlTypes;
+
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
     cppLvlTypes.push_back(static_cast<LevelType>(lvlTypes[l]));
-  return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
-                                            unwrap(dimToLvl), unwrap(lvlToDim),
-                                            posWidth, crdWidth));
+
+  return wrap(SparseTensorEncodingAttr::get(
+      unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth,
+      crdWidth, unwrap(explicitVal), unwrap(implicitVal)));
 }
 
 MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) {
@@ -91,6 +93,14 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
 }
 
+MlirAttribute mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr) {
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getExplicitVal());
+}
+
+MlirAttribute mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr) {
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getImplicitVal());
+}
+
 MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
     enum MlirSparseTensorLevelFormat lvlFmt,
     const enum MlirSparseTensorLevelPropertyNondefault *properties,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 516b0943bdcfac..25c38043613daf 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -326,9 +326,9 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
-                                       AffineMap(), getPosWidth(),
-                                       getCrdWidth());
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
+      getCrdWidth(), getExplicitVal(), getImplicitVal());
 }
 
 SparseTensorEncodingAttr
@@ -344,20 +344,44 @@ SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                         unsigned crdWidth) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
-                                       getDimToLvl(), getLvlToDim(), posWidth,
-                                       crdWidth);
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
+      crdWidth, getExplicitVal(), getImplicitVal());
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
   return withBitWidths(0, 0);
 }
 
+SparseTensorEncodingAttr
+SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
+  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
+      getCrdWidth(), explicitVal, getImplicitVal());
+}
+
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
+  return withExplicitVal(Attribute());
+}
+
+SparseTensorEncodingAttr
+SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
+  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
+      getCrdWidth(), getExplicitVal(), implicitVal);
+}
+
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
+  return withImplicitVal(Attribute());
+}
+
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
     ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
-  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
-                                       getDimToLvl(), getLvlToDim(),
-                                       getPosWidth(), getCrdWidth(), dimSlices);
+  return SparseTensorEncodingAttr::get(
+      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
+      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
@@ -553,8 +577,11 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   AffineMap lvlToDim = {};
   unsigned posWidth = 0;
   unsigned crdWidth = 0;
+  Attribute explicitVal;
+  Attribute implicitVal;
   StringRef attrName;
-  SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
+  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
+                                    "explicitVal", "implicitVal"};
   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
     // Detect admissible keyword.
     auto *it = find(keys, attrName);
@@ -628,6 +655,36 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
       crdWidth = intAttr.getInt();
       break;
     }
+    case 3: { // explicitVal
+      Attribute attr;
+      if (failed(parser.parseAttribute(attr)))
+        return {};
+      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
+        explicitVal = result;
+      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
+        explicitVal = result;
+      } else {
+        parser.emitError(parser.getNameLoc(),
+                         "expected a numeric value for explicitVal");
+        return {};
+      }
+      break;
+    }
+    case 4: { // implicitVal
+      Attribute attr;
+      if (failed(parser.parseAttribute(attr)))
+        return {};
+      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
+        implicitVal = result;
+      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
+        implicitVal = result;
+      } else {
+        parser.emitError(parser.getNameLoc(),
+                         "expected a numeric value for implicitVal");
+        return {};
+      }
+      break;
+    }
     } // switch
     // Only last item can omit the comma.
     if (parser.parseOptionalComma().failed())
@@ -646,7 +703,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   }
   return parser.getChecked<SparseTensorEncodingAttr>(
       parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
-      dimSlices);
+      explicitVal, implicitVal, dimSlices);
 }
 
 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
@@ -666,6 +723,11 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
     printer << ", posWidth = " << getPosWidth();
   if (getCrdWidth())
     printer << ", crdWidth = " << getCrdWidth();
+  if (getExplicitVal()) {
+    printer << ", explicitVal = " << getExplicitVal();
+  }
+  if (getImplicitVal())
+    printer << ", implicitVal = " << getImplicitVal();
   printer << " }>";
 }
 
@@ -715,7 +777,8 @@ void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
 LogicalResult SparseTensorEncodingAttr::verify(
     function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
     AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
-    unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
+    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
+    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
   if (!acceptBitWidth(posWidth))
     return emitError() << "unexpected position bitwidth: " << posWidth;
   if (!acceptBitWidth(crdWidth))
@@ -831,7 +894,8 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
   // Check structural integrity.  In particular, this ensures that the
   // level-rank is coherent across all the fields.
   if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
-                    getPosWidth(), getCrdWidth(), getDimSlices())))
+                    getPosWidth(), getCrdWidth(), getExplicitVal(),
+                    getImplicitVal(), getDimSlices())))
     return failure();
   // Check integrity with tensor type specifics.  In particular, we
   // need only check that the dimension-rank of the tensor agrees with
@@ -921,9 +985,9 @@ mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
     // Ends by a unique singleton level.
     lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
   }
-  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
-                                           getDimToLvl(), getLvlToDim(),
-                                           getPosWidth(), getCrdWidth());
+  auto enc = SparseTensorEncodingAttr::get(
+      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
+      getCrdWidth(), getExplicitVal(), getImplicitVal());
   return RankedTensorType::get(getDimShape(), getElementType(), enc);
 }
 
@@ -1115,7 +1179,10 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
       // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
       // value for different bitwidth, it also avoids casting between index and
       // integer (returned by DimOp)
-      0, 0, enc.getDimSlices());
+      0, 0,
+      Attribute(), // explicitVal (irrelevant to storage specifier)
+      Attribute(), // implicitVal (irrelevant to storage specifier)
+      enc.getDimSlices());
 }
 
 StorageSpecifierType
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index f241e0e5c2fb56..e387079d5db7a0 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -27,7 +27,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
   const char *originalAsm =
     "#sparse_tensor.encoding<{ "
     "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), "
-    "posWidth = 32, crdWidth = 64 }>";
+    "posWidth = 32, crdWidth = 64, explicitVal = 1 : i64}>";
   // clang-format on
   MlirAttribute originalAttr =
       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(originalAsm));
@@ -56,8 +56,19 @@ static int testRoundtripEncoding(MlirContext ctx) {
   // CHECK: crdWidth: 64
   int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
   fprintf(stderr, "crdWidth: %d\n", crdWidth);
+
+  // CHECK: explicitVal: 1 : i64
+  MlirAttribute explicitVal = mlirSparseTensorEncodingAttrGetExplicitVal(originalAttr);
+  fprintf(stderr, "explicitVal: ");
+  mlirAttributeDump(explicitVal);
+  // CHECK: implicitVal: <<NULL ATTRIBUTE>>
+  MlirAttribute implicitVal =
+      mlirSparseTensorEncodingAttrGetImplicitVal(originalAttr);
+  fprintf(stderr, "implicitVal: ");
+  mlirAttributeDump(implicitVal);
+
   MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
-      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth);
+      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, explicitVal, implicitVal);
   mlirAttributeDump(newAttr); // For debugging filecheck output.
   // CHECK: equal: 1
   fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 66e61afd897dd1..1d674036ddea7a 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -22,6 +22,64 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
 
 // -----
 
+#CSR = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 64,
+  crdWidth = 64,
+  explicitVal = 1.0 : f32,
+  implicitVal = 1.0 : f32
+}>
+
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1.000000e+00 : f32, implicitVal = 1.000000e+00 : f32 }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xf32, #[[$CSR]]>)
+func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = 1.0 : f64,
+  implicitVal = 1.0 : f64
+}>
+
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), explicitVal = 1.000000e+00 : f64, implicitVal = 1.000000e+00 : f64 }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xf64, #[[$CSR]]>)
+func.func private @sparse_csr(tensor<?x?xf64, #CSR>)
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 64,
+  crdWidth = 64,
+  explicitVal = 1 : i32,
+  implicitVal = 1 : i32
+}>
+
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i32, implicitVal = 1 : i32 }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xi32, #[[$CSR]]>)
+func.func private @sparse_csr(tensor<?x?xi32, #CSR>)
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 64,
+  crdWidth = 64,
+  explicitVal = 1 : i64,
+  implicitVal = 1 : i64
+}>
+
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 1 : i64 }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xi64, #[[$CSR]]>)
+func.func private @sparse_csr(tensor<?x?xi64, #CSR>)
+
+// -----
+
 #BCSR = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
 }>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 5666d090c3d5ee..06e0a253861db8 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -18,10 +18,11 @@ def testEncodingAttr1D():
             "#sparse_tensor.encoding<{"
             "  map = (d0) -> (d0 : compressed),"
             "  posWidth = 16,"
-            "  crdWidth = 32"
+            "  crdWidth = 32,"
+            "  explicitVal = 1.0 : f64"
             "}>"
         )
-        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32, explicitVal = 1.000000e+00 : f64 }>
         print(parsed)
 
         casted = st.EncodingAttr(parsed)
@@ -38,9 +39,14 @@ def testEncodingAttr1D():
         print(f"pos_width: {casted.pos_width}")
         # CHECK: crd_width: 32
         print(f"crd_width: {casted.crd_width}")
-
-        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
-        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+        # CHECK: explicit_val: 1.000000e+00
+        print(f"explicit_val: {casted.explicit_val}")
+        # CHECK: implicit_val: None
+        print(f"implicit_val: {casted.implicit_val}")
+
+        new_explicit_val = FloatAttr.get_f64(1.0)
+        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0, new_explicit_val)
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), explicitVal = 1.000000e+00 : f64 }>
         print(created)
         # CHECK: created_equal: False
         print(f"created_equal: {created == casted}")

>From 8c3f53deb0f9c8d1d284d9de52cab6ec071ea70a Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 16 Apr 2024 19:32:45 +0000
Subject: [PATCH 2/6] modify comment

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td    | 6 +++---
 .../include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index d297a238b6fc3c..472ebfacdea464 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -167,7 +167,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **soa** : only applicable to singleton levels, fuses the singleton
       level in SoA (structure of arrays) scheme.
 
-    In addition to the map, the following two fields are optional:
+    In addition to the map, the following four fields are optional:
 
     - The required bitwidth for position storage (integral offsets
       into the sparse storage scheme).  A narrow width reduces the memory
@@ -183,11 +183,11 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
       coordinate over all levels).  The choices are `8`, `16`, `32`,
       `64`, or, the default, `0` to indicate a native bitwidth.
 
-    - The required explicit value for the sparse tensor. If explicitVal is set,
+    - The explicit value for the sparse tensor. If explicitVal is set,
       then all the non-zero values in the tensor have the same explicit value.
       The default value Attribute() indicates that it is not set.
 
-    - The required implicit value for the sparse tensor. If implicitVal is set,
+    - The implicit value for the sparse tensor. If implicitVal is set,
       then the "zero" value in the tensor is equal to the implicit value.
       For now, we only support `0` as the implicit value but it could be
       extended in the future. The default value Attribute() indicates that
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 25f4ecad27519e..34d99913fbd51b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -343,10 +343,10 @@ class SparseTensorType {
   /// Returns the position-overhead bitwidth, defaulting to zero.
   unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }
 
-  /// Returns the explicit value, defaulting to empty Attribute.
+  /// Returns the explicit value, defaulting to null Attribute for unset.
   Attribute getExplicitVal() const { return enc.getExplicitVal(); }
 
-  /// Returns the implicit value, defaulting to empty Attribute.
+  /// Returns the implicit value, defaulting to null Attribute for 0.
   Attribute getImplicitVal() const { return enc.getImplicitVal(); }
 
   /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.

>From 77a877e4d2f6c22281841fb1b1ba2cf8db6f617e Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 16 Apr 2024 19:53:55 +0000
Subject: [PATCH 3/6] format

---
 mlir/test/CAPI/sparse_tensor.c                     | 6 ++++--
 mlir/test/python/dialects/sparse_tensor/dialect.py | 4 +++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index e387079d5db7a0..22b7052b732aa2 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -58,7 +58,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   fprintf(stderr, "crdWidth: %d\n", crdWidth);
 
   // CHECK: explicitVal: 1 : i64
-  MlirAttribute explicitVal = mlirSparseTensorEncodingAttrGetExplicitVal(originalAttr);
+  MlirAttribute explicitVal =
+      mlirSparseTensorEncodingAttrGetExplicitVal(originalAttr);
   fprintf(stderr, "explicitVal: ");
   mlirAttributeDump(explicitVal);
   // CHECK: implicitVal: <<NULL ATTRIBUTE>>
@@ -68,7 +69,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
   mlirAttributeDump(implicitVal);
 
   MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
-      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, explicitVal, implicitVal);
+      ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
+      explicitVal, implicitVal);
   mlirAttributeDump(newAttr); // For debugging filecheck output.
   // CHECK: equal: 1
   fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 06e0a253861db8..21f28c7437c43f 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -45,7 +45,9 @@ def testEncodingAttr1D():
         print(f"implicit_val: {casted.implicit_val}")
 
         new_explicit_val = FloatAttr.get_f64(1.0)
-        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0, new_explicit_val)
+        created = st.EncodingAttr.get(
+            casted.lvl_types, None, None, 0, 0, new_explicit_val
+        )
         # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), explicitVal = 1.000000e+00 : f64 }>
         print(created)
         # CHECK: created_equal: False

>From 0a848720601e842423cb5166308f33e76b9fcbeb Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 16 Apr 2024 21:00:14 +0000
Subject: [PATCH 4/6] change test for implicit value

---
 .../Dialect/SparseTensor/roundtrip_encoding.mlir | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 1d674036ddea7a..dba95a0b29a46f 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -27,10 +27,10 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
   posWidth = 64,
   crdWidth = 64,
   explicitVal = 1.0 : f32,
-  implicitVal = 1.0 : f32
+  implicitVal = 0.0 : f32
 }>
 
-// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1.000000e+00 : f32, implicitVal = 1.000000e+00 : f32 }>
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1.000000e+00 : f32, implicitVal = 0.000000e+00 : f32 }>
 // CHECK-LABEL: func private @sparse_csr(
 // CHECK-SAME: tensor<?x?xf32, #[[$CSR]]>)
 func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
@@ -40,10 +40,10 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
 #CSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed),
   explicitVal = 1.0 : f64,
-  implicitVal = 1.0 : f64
+  implicitVal = 0.0 : f64
 }>
 
-// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), explicitVal = 1.000000e+00 : f64, implicitVal = 1.000000e+00 : f64 }>
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), explicitVal = 1.000000e+00 : f64, implicitVal = 0.000000e+00 : f64 }>
 // CHECK-LABEL: func private @sparse_csr(
 // CHECK-SAME: tensor<?x?xf64, #[[$CSR]]>)
 func.func private @sparse_csr(tensor<?x?xf64, #CSR>)
@@ -55,10 +55,10 @@ func.func private @sparse_csr(tensor<?x?xf64, #CSR>)
   posWidth = 64,
   crdWidth = 64,
   explicitVal = 1 : i32,
-  implicitVal = 1 : i32
+  implicitVal = 0 : i32
 }>
 
-// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i32, implicitVal = 1 : i32 }>
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i32, implicitVal = 0 : i32 }>
 // CHECK-LABEL: func private @sparse_csr(
 // CHECK-SAME: tensor<?x?xi32, #[[$CSR]]>)
 func.func private @sparse_csr(tensor<?x?xi32, #CSR>)
@@ -70,10 +70,10 @@ func.func private @sparse_csr(tensor<?x?xi32, #CSR>)
   posWidth = 64,
   crdWidth = 64,
   explicitVal = 1 : i64,
-  implicitVal = 1 : i64
+  implicitVal = 0 : i64
 }>
 
-// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 1 : i64 }>
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 0 : i64 }>
 // CHECK-LABEL: func private @sparse_csr(
 // CHECK-SAME: tensor<?x?xi64, #[[$CSR]]>)
 func.func private @sparse_csr(tensor<?x?xi64, #CSR>)

>From b6d0fd043cf9a6e0b3a79db1b7dbc3689fe2f27d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 18 Apr 2024 20:07:22 +0000
Subject: [PATCH 5/6] address comments

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  7 ++-
 .../SparseTensor/roundtrip_encoding.mlir      | 32 +++++------
 .../python/dialects/sparse_tensor/dialect.py  | 57 ++++++++++++-------
 3 files changed, 56 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 472ebfacdea464..04173f397cbd61 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -167,7 +167,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **soa** : only applicable to singleton levels, fuses the singleton
       level in SoA (structure of arrays) scheme.
 
-    In addition to the map, the following four fields are optional:
+    In addition to the map, the following fields are optional:
 
     - The required bitwidth for position storage (integral offsets
       into the sparse storage scheme).  A narrow width reduces the memory
@@ -185,7 +185,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
 
     - The explicit value for the sparse tensor. If explicitVal is set,
       then all the non-zero values in the tensor have the same explicit value.
-      The default value Attribute() indicates that it is not set.
+      The default value Attribute() indicates that it is not set. Setting
+      an explicit value is useful for binary-valued tensors, whose values
+      can only be 0 or 1: the explicit value can be set to 1 instead of
+      storing the values array.
 
     - The implicit value for the sparse tensor. If implicitVal is set,
       then the "zero" value in the tensor is equal to the implicit value.
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index dba95a0b29a46f..7eeda9a9880268 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -22,7 +22,7 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
 
 // -----
 
-#CSR = #sparse_tensor.encoding<{
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed),
   posWidth = 64,
   crdWidth = 64,
@@ -30,27 +30,27 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
   implicitVal = 0.0 : f32
 }>
 
-// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1.000000e+00 : f32, implicitVal = 0.000000e+00 : f32 }>
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1.000000e+00 : f32, implicitVal = 0.000000e+00 : f32 }>
 // CHECK-LABEL: func private @sparse_csr(
-// CHECK-SAME: tensor<?x?xf32, #[[$CSR]]>)
-func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
+// CHECK-SAME: tensor<?x?xf32, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_OnlyOnes>)
 
 // -----
 
-#CSR = #sparse_tensor.encoding<{
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed),
   explicitVal = 1.0 : f64,
   implicitVal = 0.0 : f64
 }>
 
-// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), explicitVal = 1.000000e+00 : f64, implicitVal = 0.000000e+00 : f64 }>
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), explicitVal = 1.000000e+00 : f64, implicitVal = 0.000000e+00 : f64 }>
 // CHECK-LABEL: func private @sparse_csr(
-// CHECK-SAME: tensor<?x?xf64, #[[$CSR]]>)
-func.func private @sparse_csr(tensor<?x?xf64, #CSR>)
+// CHECK-SAME: tensor<?x?xf64, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xf64, #CSR_OnlyOnes>)
 
 // -----
 
-#CSR = #sparse_tensor.encoding<{
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed),
   posWidth = 64,
   crdWidth = 64,
@@ -58,14 +58,14 @@ func.func private @sparse_csr(tensor<?x?xf64, #CSR>)
   implicitVal = 0 : i32
 }>
 
-// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i32, implicitVal = 0 : i32 }>
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i32, implicitVal = 0 : i32 }>
 // CHECK-LABEL: func private @sparse_csr(
-// CHECK-SAME: tensor<?x?xi32, #[[$CSR]]>)
-func.func private @sparse_csr(tensor<?x?xi32, #CSR>)
+// CHECK-SAME: tensor<?x?xi32, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_OnlyOnes>)
 
 // -----
 
-#CSR = #sparse_tensor.encoding<{
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed),
   posWidth = 64,
   crdWidth = 64,
@@ -73,10 +73,10 @@ func.func private @sparse_csr(tensor<?x?xi32, #CSR>)
   implicitVal = 0 : i64
 }>
 
-// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 0 : i64 }>
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 0 : i64 }>
 // CHECK-LABEL: func private @sparse_csr(
-// CHECK-SAME: tensor<?x?xi64, #[[$CSR]]>)
-func.func private @sparse_csr(tensor<?x?xi64, #CSR>)
+// CHECK-SAME: tensor<?x?xi64, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xi64, #CSR_OnlyOnes>)
 
 // -----
 
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 21f28c7437c43f..713a914f4b36fe 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -2,6 +2,7 @@
 
 from mlir.ir import *
 from mlir.dialects import sparse_tensor as st
+import textwrap
 
 
 def run(f):
@@ -15,12 +16,15 @@ def run(f):
 def testEncodingAttr1D():
     with Context() as ctx:
         parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0) -> (d0 : compressed),"
-            "  posWidth = 16,"
-            "  crdWidth = 32,"
-            "  explicitVal = 1.0 : f64"
-            "}>"
+            textwrap.dedent("""\
+                #sparse_tensor.encoding<{
+                    map = (d0) -> (d0 : compressed),
+                    posWidth = 16,
+                    crdWidth = 32,
+                    explicitVal = 1.0 : f64
+                }>\
+            """
+            )
         )
         # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32, explicitVal = 1.000000e+00 : f64 }>
         print(parsed)
@@ -65,12 +69,15 @@ def testEncodingAttr1D():
 def testEncodingAttrStructure():
     with Context() as ctx:
         parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,"
-            "  d1 mod 4 : structured[2, 4]),"
-            "  posWidth = 16,"
-            "  crdWidth = 32"
-            "}>"
+            textwrap.dedent("""\
+                #sparse_tensor.encoding<{
+                    map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,
+                    d1 mod 4 : structured[2, 4]),
+                    posWidth = 16,
+                    crdWidth = 32,
+                }>\
+            """
+            )
         )
         # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
         print(parsed)
@@ -152,11 +159,14 @@ def testEncodingAttrStructure():
 def testEncodingAttr2D():
     with Context() as ctx:
         parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
-            "  posWidth = 8,"
-            "  crdWidth = 32"
-            "}>"
+            textwrap.dedent("""\
+                #sparse_tensor.encoding<{
+                    map = (d0, d1) -> (d1 : dense, d0 : compressed),
+                    posWidth = 8,
+                    crdWidth = 32,
+                }>\
+            """
+            )
         )
         # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
         print(parsed)
@@ -195,11 +205,14 @@ def testEncodingAttrOnTensorType():
     with Context() as ctx, Location.unknown():
         encoding = st.EncodingAttr(
             Attribute.parse(
-                "#sparse_tensor.encoding<{"
-                "  map = (d0) -> (d0 : compressed), "
-                "  posWidth = 64,"
-                "  crdWidth = 32"
-                "}>"
+                textwrap.dedent("""\
+                    #sparse_tensor.encoding<{
+                        map = (d0) -> (d0 : compressed),
+                        posWidth = 64,
+                        crdWidth = 32,
+                    }>\
+                """
+                )
             )
         )
         tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)

>From 13315fa5a68132fe00ec0bbadb7ee3efceb69afb Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 18 Apr 2024 20:27:47 +0000
Subject: [PATCH 6/6] python format

---
 mlir/test/python/dialects/sparse_tensor/dialect.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 713a914f4b36fe..3cc4575eb3e240 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -16,7 +16,8 @@ def run(f):
 def testEncodingAttr1D():
     with Context() as ctx:
         parsed = Attribute.parse(
-            textwrap.dedent("""\
+            textwrap.dedent(
+                """\
                 #sparse_tensor.encoding<{
                     map = (d0) -> (d0 : compressed),
                     posWidth = 16,
@@ -69,7 +70,8 @@ def testEncodingAttr1D():
 def testEncodingAttrStructure():
     with Context() as ctx:
         parsed = Attribute.parse(
-            textwrap.dedent("""\
+            textwrap.dedent(
+                """\
                 #sparse_tensor.encoding<{
                     map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,
                     d1 mod 4 : structured[2, 4]),
@@ -159,7 +161,8 @@ def testEncodingAttrStructure():
 def testEncodingAttr2D():
     with Context() as ctx:
         parsed = Attribute.parse(
-            textwrap.dedent("""\
+            textwrap.dedent(
+                """\
                 #sparse_tensor.encoding<{
                     map = (d0, d1) -> (d1 : dense, d0 : compressed),
                     posWidth = 8,
@@ -205,7 +208,8 @@ def testEncodingAttrOnTensorType():
     with Context() as ctx, Location.unknown():
         encoding = st.EncodingAttr(
             Attribute.parse(
-                textwrap.dedent("""\
+                textwrap.dedent(
+                    """\
                     #sparse_tensor.encoding<{
                         map = (d0) -> (d0 : compressed),
                         posWidth = 64,



More information about the Mlir-commits mailing list