[Mlir-commits] [mlir] 885a1f4 - [mlir][sparse] support parsing slices in sparse tensor encoding attribute

Peiming Liu llvmlistbot at llvm.org
Thu Jan 12 14:35:30 PST 2023


Author: Peiming Liu
Date: 2023-01-12T22:35:24Z
New Revision: 885a1f431621c0a2bc9b17f7dae401d62646baae

URL: https://github.com/llvm/llvm-project/commit/885a1f431621c0a2bc9b17f7dae401d62646baae
DIFF: https://github.com/llvm/llvm-project/commit/885a1f431621c0a2bc9b17f7dae401d62646baae.diff

LOG: [mlir][sparse] support parsing slices in sparse tensor encoding attribute

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D140712

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
    mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 37691253e58d6..43c493c1e0f56 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -14,15 +14,69 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
 include "mlir/IR/TensorEncoding.td"
 
-//===----------------------------------------------------------------------===//
-// Sparse Tensor Type Encoding Attribute
-//===----------------------------------------------------------------------===//
-
 // All of the Tensor attributes will extend this class.
 class SparseTensor_Attr<string name,
                         list<Trait> traits = []>
     : AttrDef<SparseTensor_Dialect, name, traits>;
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dimension Slice Attribute.
+//===----------------------------------------------------------------------===//
+
+def SparseTensorDimSliceAttr : SparseTensor_Attr<"SparseTensorDimSlice", []> {
+  let mnemonic = "slice";
+
+  let description = [{
+    An attribute to encode slice information of a sparse tensor on a particular
+    dimension (a tuple of offset, size, stride).
+  }];
+
+  let parameters = (
+    ins
+    "int64_t" : $offset,
+    "int64_t" : $size,
+    "int64_t" : $stride
+  );
+
+  let extraClassDeclaration = [{
+    /// Special value for dynamic offset/size/stride.
+    static constexpr int64_t kDynamic = -1;
+
+    static bool isDynamic(int64_t v) {
+      return v == kDynamic;
+    }
+
+    std::optional<uint64_t> getStaticOffset() const {
+      if (isDynamic(getOffset()))
+        return std::nullopt;
+      return static_cast<uint64_t>(getOffset());
+    };
+
+    std::optional<uint64_t> getStaticStride() const {
+      if (isDynamic(getStride()))
+        return std::nullopt;
+      return static_cast<uint64_t>(getStride());
+    }
+
+    std::optional<uint64_t> getStaticSize() const {
+      if (isDynamic(getSize()))
+        return std::nullopt;
+      return static_cast<uint64_t>(getSize());
+    }
+
+    bool isCompletelyDynamic() const {
+      return isDynamic(getOffset()) && isDynamic(getStride()) && isDynamic(getSize());
+    };
+  }];
+
+  let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Type Encoding Attribute.
+//===----------------------------------------------------------------------===//
+
 // Sparse tensor encoding attribute.
 def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
          [ DeclareAttrInterfaceMethods<VerifiableTensorEncoding> ] > {
@@ -103,6 +157,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
       choices are `8`, `16`, `32`, `64`, or, the default, `0` to indicate a
       native bit width.
 
+    - An optional array of SparseTensorDimSliceAttr, which specifies how the sparse
+      tensor is partitioned on each level.
+
     Examples:
 
     ```mlir
@@ -142,6 +199,15 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
       higherOrdering = affine_map<(i, j)[c] -> (c * 4 * i, i, j)>
     }>
     ... tensor<?x?xf64, #ELL> ...
+
+    // CSR slice (offset = 0, size = 4, stride = 1 on the first dimension;
+    // offset = 0, size = 8, and a dynamic stride on the second dimension).
+    #CSR_SLICE = #sparse_tensor.encoding<{
+      dimLevelType = [ "dense", "compressed" ],
+      slice = [ (0, 4, 1), (0, 8, ?) ]
+    }>
+    ... tensor<?x?xf64, #CSR_SLICE> ...
+
     ```
   }];
 
@@ -160,9 +226,29 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     // The required bit width for pointer storage.
     "unsigned":$pointerBitWidth,
     // The required bit width for index storage.
-    "unsigned":$indexBitWidth
+    "unsigned":$indexBitWidth,
+    // A dimension level type for each dimension of the tensor type.
+    ArrayRefParameter<
+      "::mlir::sparse_tensor::SparseTensorDimSliceAttr",
+      "per dimension slice metadata"
+      >: $dimSlices
   );
 
+  let builders = [
+    AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$dimLevelType,
+                     "AffineMap":$dimOrdering,
+                     "AffineMap":$higherOrdering,
+                     "unsigned":$pointerBitWidth,
+                     "unsigned":$indexBitWidth), [{
+      return $_get($_ctxt, dimLevelType,
+                         dimOrdering,
+                         higherOrdering,
+                         pointerBitWidth,
+                         indexBitWidth,
+                         ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
+    }]>
+  ];
+
   let extraClassDeclaration = [{
     /// Returns the type for pointer storage based on pointerBitWidth
     Type getPointerType() const;
@@ -179,6 +265,17 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
 
     /// Return true if the encoding has an identity dimension ordering.
     bool hasIdDimOrdering() const;
+
+    bool isSlice() const {
+      return !getDimSlices().empty();
+    }
+
+    std::optional<uint64_t> getStaticDimSliceOffset(unsigned dim) const;
+    std::optional<uint64_t> getStaticDimSliceSize(unsigned dim) const;
+    std::optional<uint64_t> getStaticDimSliceStride(unsigned dim) const;
+    std::optional<uint64_t> getStaticLvlSliceOffset(unsigned lvl) const;
+    std::optional<uint64_t> getStaticLvlSliceSize(unsigned lvl) const;
+    std::optional<uint64_t> getStaticLvlSliceStride(unsigned lvl) const;
   }];
 
   let genVerifyDecl = 1;
@@ -186,7 +283,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
 }
 
 //===----------------------------------------------------------------------===//
-// Sparse Tensor Storage Specifier Enum Attribute
+// Sparse Tensor Storage Specifier Enum Attribute.
 //===----------------------------------------------------------------------===//
 
 // The C++ enum for Storage Specifier kind.
@@ -209,7 +306,7 @@ def SparseTensorStorageSpecifierKindAttr
 }
 
 //===----------------------------------------------------------------------===//
-// Sparse Tensor Traits
+// Sparse Tensor Traits.
 //===----------------------------------------------------------------------===//
 
 def IsSparseTensorPred

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 59a4b1360b473..b15c9c5386871 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -45,6 +45,62 @@ static bool acceptBitWidth(unsigned bitWidth) {
   }
 }
 
+void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
+  printer << "(";
+  printer << (getStaticOffset() ? std::to_string(*getStaticOffset()) : "?");
+  printer << ", ";
+  printer << (getStaticSize() ? std::to_string(*getStaticSize()) : "?");
+  printer << ", ";
+  printer << (getStaticStride() ? std::to_string(*getStaticStride()) : "?");
+  printer << ")";
+}
+
+static ParseResult parseOptionalStaticSlice(int64_t &result,
+                                            AsmParser &parser) {
+  auto parseResult = parser.parseOptionalInteger(result);
+  if (parseResult.has_value()) {
+    if (parseResult.value().succeeded() && result < 0) {
+      parser.emitError(
+          parser.getCurrentLocation(),
+          "expect positive value or ? for slice offset/size/stride");
+      return failure();
+    }
+    return parseResult.value();
+  }
+
+  // Else, expect a '?' symbol, which represents a dynamic offset/size/stride.
+  result = SparseTensorDimSliceAttr::kDynamic;
+  return parser.parseQuestion();
+}
+
+Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
+  int64_t offset = -1, size = -1, stride = -1;
+
+  if (failed(parser.parseLParen()) ||
+      failed(parseOptionalStaticSlice(offset, parser)) ||
+      failed(parser.parseComma()) ||
+      failed(parseOptionalStaticSlice(size, parser)) ||
+      failed(parser.parseComma()) ||
+      failed(parseOptionalStaticSlice(stride, parser)) ||
+      failed(parser.parseRParen()))
+    return {};
+
+  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
+                                                     offset, size, stride);
+}
+
+LogicalResult
+SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 int64_t offset, int64_t size, int64_t stride) {
+  if ((offset == SparseTensorDimSliceAttr::kDynamic || offset >= 0) &&
+      (size == SparseTensorDimSliceAttr::kDynamic || size > 0) &&
+      (stride == SparseTensorDimSliceAttr::kDynamic || stride > 0)) {
+    return success();
+  }
+  return emitError()
+         << "expect positive value or ? for slice offset/size/stride";
+}
+
 Type SparseTensorEncodingAttr::getPointerType() const {
   unsigned ptrWidth = getPointerBitWidth();
   Type indexType = IndexType::get(getContext());
@@ -71,24 +127,70 @@ bool SparseTensorEncodingAttr::hasIdDimOrdering() const {
   return !getDimOrdering() || getDimOrdering().isIdentity();
 }
 
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticDimSliceOffset(unsigned dim) const {
+  return getDimSlices()[dim].getStaticOffset();
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticDimSliceSize(unsigned dim) const {
+  return getDimSlices()[dim].getStaticSize();
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticDimSliceStride(unsigned dim) const {
+  return getDimSlices()[dim].getStaticStride();
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticLvlSliceOffset(unsigned lvl) const {
+  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticLvlSliceSize(unsigned lvl) const {
+  return getStaticDimSliceSize(toOrigDim(*this, lvl));
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const {
+  return getStaticDimSliceStride(toOrigDim(*this, lvl));
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
-  if (failed(parser.parseLess()))
-    return {};
-  // Parse the data as a dictionary.
-  DictionaryAttr dict;
-  if (failed(parser.parseAttribute(dict)))
-    return {};
-  if (failed(parser.parseGreater()))
-    return {};
+#define RETURN_ON_FAIL(stmt)                                                   \
+  if (failed(stmt)) {                                                          \
+    return {};                                                                 \
+  }
+
+  RETURN_ON_FAIL(parser.parseLess())
+  RETURN_ON_FAIL(parser.parseLBrace())
+
   // Process the data from the parsed dictionary value into struct-like data.
   SmallVector<DimLevelType> dlt;
+  SmallVector<SparseTensorDimSliceAttr> slices;
   AffineMap dimOrd = {};
   AffineMap higherOrd = {};
   unsigned ptr = 0;
   unsigned ind = 0;
-  for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "dimLevelType") {
-      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
+
+  StringRef attrName;
+  // Exactly 6 keys.
+  SmallVector<StringRef, 6> keys = {"dimLevelType",   "dimOrdering",
+                                    "higherOrdering", "pointerBitWidth",
+                                    "indexBitWidth",  "slice"};
+  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
+    if (!llvm::is_contained(keys, attrName)) {
+      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
+      return {};
+    }
+
+    // Consume the `=` after keys
+    RETURN_ON_FAIL(parser.parseEqual())
+    if (attrName == "dimLevelType") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr));
+      auto arrayAttr = attr.dyn_cast<ArrayAttr>();
       if (!arrayAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an array for dimension level types");
@@ -127,47 +229,80 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
           return {};
         }
       }
-    } else if (attr.getName() == "dimOrdering") {
-      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
+    } else if (attrName == "dimOrdering") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+      auto affineAttr = attr.dyn_cast<AffineMapAttr>();
       if (!affineAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an affine map for dimension ordering");
         return {};
       }
       dimOrd = affineAttr.getValue();
-    } else if (attr.getName() == "higherOrdering") {
-      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
+    } else if (attrName == "higherOrdering") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+      auto affineAttr = attr.dyn_cast<AffineMapAttr>();
       if (!affineAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an affine map for higher ordering");
         return {};
       }
       higherOrd = affineAttr.getValue();
-    } else if (attr.getName() == "pointerBitWidth") {
-      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
+    } else if (attrName == "pointerBitWidth") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+      auto intAttr = attr.dyn_cast<IntegerAttr>();
       if (!intAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an integral pointer bitwidth");
         return {};
       }
       ptr = intAttr.getInt();
-    } else if (attr.getName() == "indexBitWidth") {
-      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
+    } else if (attrName == "indexBitWidth") {
+      Attribute attr;
+      RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+      auto intAttr = attr.dyn_cast<IntegerAttr>();
       if (!intAttr) {
         parser.emitError(parser.getNameLoc(),
                          "expected an integral index bitwidth");
         return {};
       }
       ind = intAttr.getInt();
-    } else {
-      parser.emitError(parser.getNameLoc(), "unexpected key: ")
-          << attr.getName().strref();
-      return {};
+    } else if (attrName == "slice") {
+      RETURN_ON_FAIL(parser.parseLSquare())
+      // Dispatches to DimSliceAttr to skip mnemonic
+      bool finished = false;
+      while (auto attr = SparseTensorDimSliceAttr::parse(parser, nullptr)) {
+        auto sliceAttr = attr.cast<SparseTensorDimSliceAttr>();
+        slices.push_back(sliceAttr);
+        if (parser.parseOptionalComma().failed()) {
+          finished = true;
+          break;
+        }
+      }
+      // An error occurred while parsing the slice list; bail out.
+      if (!finished)
+        return {};
+      RETURN_ON_FAIL(parser.parseRSquare())
     }
+
+    // Only the last item can omit the comma
+    if (parser.parseOptionalComma().failed())
+      break;
   }
+
+  RETURN_ON_FAIL(parser.parseRBrace())
+  RETURN_ON_FAIL(parser.parseGreater())
+#undef RETURN_ON_FAIL
+
   // Construct struct-like storage for attribute.
   return parser.getChecked<SparseTensorEncodingAttr>(
-      parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind);
+      parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind, slices);
 }
 
 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
@@ -188,14 +323,25 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
     printer << ", pointerBitWidth = " << getPointerBitWidth();
   if (getIndexBitWidth())
     printer << ", indexBitWidth = " << getIndexBitWidth();
+  if (!getDimSlices().empty()) {
+    printer << ", slice = [ ";
+    llvm::interleaveComma(getDimSlices(), printer,
+                          [&](SparseTensorDimSliceAttr attr) {
+                            // Calls SparseTensorDimSliceAttr::print directly to
+                            // skip mnemonic.
+                            attr.print(printer);
+                          });
+    printer << " ]";
+  }
+
   printer << " }>";
 }
 
 LogicalResult SparseTensorEncodingAttr::verify(
     function_ref<InFlightDiagnostic()> emitError,
     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
-    AffineMap higherOrdering, unsigned pointerBitWidth,
-    unsigned indexBitWidth) {
+    AffineMap higherOrdering, unsigned pointerBitWidth, unsigned indexBitWidth,
+    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
   if (!acceptBitWidth(pointerBitWidth))
     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
   if (!acceptBitWidth(indexBitWidth))
@@ -217,6 +363,11 @@ LogicalResult SparseTensorEncodingAttr::verify(
       return emitError() << "unexpected mismatch in higher ordering and "
                             "dimension level types size";
   }
+  if (!dimSlices.empty() && dimSlices.size() != dimLevelType.size()) {
+    return emitError() << "unexpected mismatch in dimension slices and "
+                          "dimension level type size";
+  }
+
   return success();
 }
 
@@ -226,7 +377,7 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
   // Check structural integrity.
   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
                     getHigherOrdering(), getPointerBitWidth(),
-                    getIndexBitWidth())))
+                    getIndexBitWidth(), getDimSlices())))
     return failure();
   // Check integrity with tensor type specifics. Dimension ordering is optional,
   // but we always should have dimension level types for the full rank.

diff  --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 4830a665bd0ac..cfbf14a799f25 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -68,3 +68,10 @@ func.func private @tensor_invalid_key(%arg0: tensor<16x32xf32, #a>) -> ()
 #a = #sparse_tensor.encoding<{dimLevelType = [ "compressed", "compressed", "dense", "dense" ], dimOrdering  = affine_map<(ii, jj, i, j) -> (ii, jj, i, j)>, higherOrdering = affine_map<(i, j) -> (j, i)>}> // expected-error {{unexpected higher ordering mapping from 2 to 2}}
 func.func private @tensor_invalid_key(%arg0: tensor<10x60xf32, #a>) -> ()
 
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (-1, ?, 1), (?, 4, 2) ] // expected-error{{expect positive value or ? for slice offset/size/stride}}
+}>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index d8eda2ac6b3db..b7bc7c619aa60 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -88,3 +88,35 @@ func.func private @sparse_bcsr(tensor<10x60xf64, #BCSR>)
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], higherOrdering = affine_map<(d0, d1)[s0] -> (d0 * (s0 * 4), d0, d1)> }>>
 func.func private @sparse_ell(tensor<?x?xf64, #ELL>)
 
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+// CHECK-LABEL: func private @sparse_slice(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], slice = [ (1, 4, 1), (1, 4, 2) ] }>>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
+
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+// CHECK-LABEL: func private @sparse_slice(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], slice = [ (1, 4, 1), (1, 4, 2) ] }>>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
+
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, ?, 1), (?, 4, 2) ]
+}>
+
+// CHECK-LABEL: func private @sparse_slice(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], slice = [ (1, ?, 1), (?, 4, 2) ] }>>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)


        


More information about the Mlir-commits mailing list