[Mlir-commits] [mlir] 885a1f4 - [mlir][sparse] support parsing slices in sparse tensor encoding attribute
Peiming Liu
llvmlistbot at llvm.org
Thu Jan 12 14:35:30 PST 2023
Author: Peiming Liu
Date: 2023-01-12T22:35:24Z
New Revision: 885a1f431621c0a2bc9b17f7dae401d62646baae
URL: https://github.com/llvm/llvm-project/commit/885a1f431621c0a2bc9b17f7dae401d62646baae
DIFF: https://github.com/llvm/llvm-project/commit/885a1f431621c0a2bc9b17f7dae401d62646baae.diff
LOG: [mlir][sparse] support parsing slices in sparse tensor encoding attribute
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D140712
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 37691253e58d6..43c493c1e0f56 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -14,15 +14,69 @@ include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
include "mlir/IR/TensorEncoding.td"
-//===----------------------------------------------------------------------===//
-// Sparse Tensor Type Encoding Attribute
-//===----------------------------------------------------------------------===//
-
// All of the Tensor attributes will extend this class.
class SparseTensor_Attr<string name,
list<Trait> traits = []>
: AttrDef<SparseTensor_Dialect, name, traits>;
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Dimension Slice Attribute.
+//===----------------------------------------------------------------------===//
+
+def SparseTensorDimSliceAttr : SparseTensor_Attr<"SparseTensorDimSlice", []> {
+ let mnemonic = "slice";
+
+ let description = [{
+ An attribute to encode slice information of a sparse tensor on a particular
+ dimension (a tuple of offset, size, stride).
+ }];
+
+ let parameters = (
+ ins
+ "int64_t" : $offset,
+ "int64_t" : $size,
+ "int64_t" : $stride
+ );
+
+ let extraClassDeclaration = [{
+ /// Special value for dynamic offset/size/stride.
+ static constexpr int64_t kDynamic = -1;
+
+ static bool isDynamic(int64_t v) {
+ return v == kDynamic;
+ }
+
+ std::optional<uint64_t> getStaticOffset() const {
+ if (isDynamic(getOffset()))
+ return std::nullopt;
+ return static_cast<uint64_t>(getOffset());
+ };
+
+ std::optional<uint64_t> getStaticStride() const {
+ if (isDynamic(getStride()))
+ return std::nullopt;
+ return static_cast<uint64_t>(getStride());
+ }
+
+ std::optional<uint64_t> getStaticSize() const {
+ if (isDynamic(getSize()))
+ return std::nullopt;
+ return static_cast<uint64_t>(getSize());
+ }
+
+ bool isCompletelyDynamic() const {
+ return isDynamic(getOffset()) && isDynamic(getStride()) && isDynamic(getSize());
+ };
+ }];
+
+ let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Type Encoding Attribute.
+//===----------------------------------------------------------------------===//
+
// Sparse tensor encoding attribute.
def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
[ DeclareAttrInterfaceMethods<VerifiableTensorEncoding> ] > {
@@ -103,6 +157,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
choices are `8`, `16`, `32`, `64`, or, the default, `0` to indicate a
native bit width.
+ - An optional array of SparseTensorDimSliceAttr, one per dimension, which
+   encodes the (offset, size, stride) of the slice that this tensor type
+   describes within the underlying sparse tensor.
+
Examples:
```mlir
@@ -142,6 +199,15 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
higherOrdering = affine_map<(i, j)[c] -> (c * 4 * i, i, j)>
}>
... tensor<?x?xf64, #ELL> ...
+
+ // CSR slice (offset = 0, size = 4, stride = 1 on the first dimension;
+ // offset = 0, size = 8, and a dynamic stride on the second dimension).
+ #CSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (0, 4, 1), (0, 8, ?) ]
+ }>
+ ... tensor<?x?xf64, #CSR_SLICE> ...
+
```
}];
@@ -160,9 +226,29 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
// The required bit width for pointer storage.
"unsigned":$pointerBitWidth,
// The required bit width for index storage.
- "unsigned":$indexBitWidth
+ "unsigned":$indexBitWidth,
+ // A slice attribute for each dimension of the tensor type.
+ ArrayRefParameter<
+ "::mlir::sparse_tensor::SparseTensorDimSliceAttr",
+ "per dimension slice metadata"
+ >: $dimSlices
);
+ let builders = [
+ AttrBuilder<(ins "ArrayRef<::mlir::sparse_tensor::DimLevelType>":$dimLevelType,
+ "AffineMap":$dimOrdering,
+ "AffineMap":$higherOrdering,
+ "unsigned":$pointerBitWidth,
+ "unsigned":$indexBitWidth), [{
+ return $_get($_ctxt, dimLevelType,
+ dimOrdering,
+ higherOrdering,
+ pointerBitWidth,
+ indexBitWidth,
+ ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
+ }]>
+ ];
+
let extraClassDeclaration = [{
/// Returns the type for pointer storage based on pointerBitWidth
Type getPointerType() const;
@@ -179,6 +265,17 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// Return true if the encoding has an identity dimension ordering.
bool hasIdDimOrdering() const;
+
+ bool isSlice() const {
+ return !getDimSlices().empty();
+ }
+
+ std::optional<uint64_t> getStaticDimSliceOffset(unsigned dim) const;
+ std::optional<uint64_t> getStaticDimSliceSize(unsigned dim) const;
+ std::optional<uint64_t> getStaticDimSliceStride(unsigned dim) const;
+ std::optional<uint64_t> getStaticLvlSliceOffset(unsigned lvl) const;
+ std::optional<uint64_t> getStaticLvlSliceSize(unsigned lvl) const;
+ std::optional<uint64_t> getStaticLvlSliceStride(unsigned lvl) const;
}];
let genVerifyDecl = 1;
@@ -186,7 +283,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
}
//===----------------------------------------------------------------------===//
-// Sparse Tensor Storage Specifier Enum Attribute
+// Sparse Tensor Storage Specifier Enum Attribute.
//===----------------------------------------------------------------------===//
// The C++ enum for Storage Specifier kind.
@@ -209,7 +306,7 @@ def SparseTensorStorageSpecifierKindAttr
}
//===----------------------------------------------------------------------===//
-// Sparse Tensor Traits
+// Sparse Tensor Traits.
//===----------------------------------------------------------------------===//
def IsSparseTensorPred
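For readers of the patch, here is a minimal sketch (not part of the change) of how the accessors declared on SparseTensorDimSliceAttr above could be exercised from C++. It assumes the default TableGen-generated `get` builder for the three int64_t parameters and a context that has the sparse tensor dialect loaded:

```c++
#include <cassert>

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

static void dimSliceExample(MLIRContext &ctx) {
  // The attribute lives in the sparse tensor dialect, so load it first.
  ctx.getOrLoadDialect<SparseTensorDialect>();
  // Offset 0 and size 4 are static; the stride is dynamic (kDynamic == -1).
  auto slice = SparseTensorDimSliceAttr::get(
      &ctx, /*offset=*/0, /*size=*/4,
      /*stride=*/SparseTensorDimSliceAttr::kDynamic);
  // Static components are returned as populated optionals.
  assert(slice.getStaticOffset() && *slice.getStaticOffset() == 0);
  assert(slice.getStaticSize() && *slice.getStaticSize() == 4);
  // Dynamic components come back as std::nullopt.
  assert(!slice.getStaticStride());
  assert(!slice.isCompletelyDynamic());
}
```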
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 59a4b1360b473..b15c9c5386871 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -45,6 +45,62 @@ static bool acceptBitWidth(unsigned bitWidth) {
}
}
+void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
+ printer << "(";
+ printer << (getStaticOffset() ? std::to_string(*getStaticOffset()) : "?");
+ printer << ", ";
+ printer << (getStaticSize() ? std::to_string(*getStaticSize()) : "?");
+ printer << ", ";
+ printer << (getStaticStride() ? std::to_string(*getStaticStride()) : "?");
+ printer << ")";
+}
+
+static ParseResult parseOptionalStaticSlice(int64_t &result,
+ AsmParser &parser) {
+ auto parseResult = parser.parseOptionalInteger(result);
+ if (parseResult.has_value()) {
+ if (parseResult.value().succeeded() && result < 0) {
+ parser.emitError(
+ parser.getCurrentLocation(),
+ "expect positive value or ? for slice offset/size/stride");
+ return failure();
+ }
+ return parseResult.value();
+ }
+
+ // Otherwise, expect a '?', which represents a dynamic offset/size/stride.
+ result = SparseTensorDimSliceAttr::kDynamic;
+ return parser.parseQuestion();
+}
+
+Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
+ int64_t offset = -1, size = -1, stride = -1;
+
+ if (failed(parser.parseLParen()) ||
+ failed(parseOptionalStaticSlice(offset, parser)) ||
+ failed(parser.parseComma()) ||
+ failed(parseOptionalStaticSlice(size, parser)) ||
+ failed(parser.parseComma()) ||
+ failed(parseOptionalStaticSlice(stride, parser)) ||
+ failed(parser.parseRParen()))
+ return {};
+
+ return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
+ offset, size, stride);
+}
+
+LogicalResult
+SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ int64_t offset, int64_t size, int64_t stride) {
+ if ((offset == SparseTensorDimSliceAttr::kDynamic || offset >= 0) &&
+ (size == SparseTensorDimSliceAttr::kDynamic || size > 0) &&
+ (stride == SparseTensorDimSliceAttr::kDynamic || stride > 0)) {
+ return success();
+ }
+ return emitError()
+ << "expect positive value or ? for slice offset/size/stride";
+}
+
Type SparseTensorEncodingAttr::getPointerType() const {
unsigned ptrWidth = getPointerBitWidth();
Type indexType = IndexType::get(getContext());
@@ -71,24 +127,70 @@ bool SparseTensorEncodingAttr::hasIdDimOrdering() const {
return !getDimOrdering() || getDimOrdering().isIdentity();
}
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticDimSliceOffset(unsigned dim) const {
+ return getDimSlices()[dim].getStaticOffset();
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticDimSliceSize(unsigned dim) const {
+ return getDimSlices()[dim].getStaticSize();
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticDimSliceStride(unsigned dim) const {
+ return getDimSlices()[dim].getStaticStride();
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticLvlSliceOffset(unsigned lvl) const {
+ return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticLvlSliceSize(unsigned lvl) const {
+ return getStaticDimSliceSize(toOrigDim(*this, lvl));
+}
+
+std::optional<uint64_t>
+SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const {
+ return getStaticDimSliceStride(toOrigDim(*this, lvl));
+}
+
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
- if (failed(parser.parseLess()))
- return {};
- // Parse the data as a dictionary.
- DictionaryAttr dict;
- if (failed(parser.parseAttribute(dict)))
- return {};
- if (failed(parser.parseGreater()))
- return {};
+#define RETURN_ON_FAIL(stmt) \
+ if (failed(stmt)) { \
+ return {}; \
+ }
+
+ RETURN_ON_FAIL(parser.parseLess())
+ RETURN_ON_FAIL(parser.parseLBrace())
+
// Process the data from the parsed dictionary value into struct-like data.
SmallVector<DimLevelType> dlt;
+ SmallVector<SparseTensorDimSliceAttr> slices;
AffineMap dimOrd = {};
AffineMap higherOrd = {};
unsigned ptr = 0;
unsigned ind = 0;
- for (const NamedAttribute &attr : dict) {
- if (attr.getName() == "dimLevelType") {
- auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
+
+ StringRef attrName;
+ // Exactly 6 keys.
+ SmallVector<StringRef, 6> keys = {"dimLevelType", "dimOrdering",
+ "higherOrdering", "pointerBitWidth",
+ "indexBitWidth", "slice"};
+ while (succeeded(parser.parseOptionalKeyword(&attrName))) {
+ if (!llvm::is_contained(keys, attrName)) {
+ parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
+ return {};
+ }
+
+ // Consume the `=` after the key.
+ RETURN_ON_FAIL(parser.parseEqual())
+ if (attrName == "dimLevelType") {
+ Attribute attr;
+ RETURN_ON_FAIL(parser.parseAttribute(attr));
+ auto arrayAttr = attr.dyn_cast<ArrayAttr>();
if (!arrayAttr) {
parser.emitError(parser.getNameLoc(),
"expected an array for dimension level types");
@@ -127,47 +229,80 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
}
}
- } else if (attr.getName() == "dimOrdering") {
- auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
+ } else if (attrName == "dimOrdering") {
+ Attribute attr;
+ RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+ auto affineAttr = attr.dyn_cast<AffineMapAttr>();
if (!affineAttr) {
parser.emitError(parser.getNameLoc(),
"expected an affine map for dimension ordering");
return {};
}
dimOrd = affineAttr.getValue();
- } else if (attr.getName() == "higherOrdering") {
- auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
+ } else if (attrName == "higherOrdering") {
+ Attribute attr;
+ RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+ auto affineAttr = attr.dyn_cast<AffineMapAttr>();
if (!affineAttr) {
parser.emitError(parser.getNameLoc(),
"expected an affine map for higher ordering");
return {};
}
higherOrd = affineAttr.getValue();
- } else if (attr.getName() == "pointerBitWidth") {
- auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
+ } else if (attrName == "pointerBitWidth") {
+ Attribute attr;
+ RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+ auto intAttr = attr.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(),
"expected an integral pointer bitwidth");
return {};
}
ptr = intAttr.getInt();
- } else if (attr.getName() == "indexBitWidth") {
- auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
+ } else if (attrName == "indexBitWidth") {
+ Attribute attr;
+ RETURN_ON_FAIL(parser.parseAttribute(attr))
+
+ auto intAttr = attr.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(),
"expected an integral index bitwidth");
return {};
}
ind = intAttr.getInt();
- } else {
- parser.emitError(parser.getNameLoc(), "unexpected key: ")
- << attr.getName().strref();
- return {};
+ } else if (attrName == "slice") {
+ RETURN_ON_FAIL(parser.parseLSquare())
+ // Call SparseTensorDimSliceAttr::parse directly to skip the mnemonic.
+ bool finished = false;
+ while (auto attr = SparseTensorDimSliceAttr::parse(parser, nullptr)) {
+ auto sliceAttr = attr.cast<SparseTensorDimSliceAttr>();
+ slices.push_back(sliceAttr);
+ if (parser.parseOptionalComma().failed()) {
+ finished = true;
+ break;
+ }
+ }
+ // Bail out if the last slice failed to parse (e.g., a trailing comma).
+ if (!finished)
+ return {};
+ RETURN_ON_FAIL(parser.parseRSquare())
}
+
+ // Only the last item can omit the comma
+ if (parser.parseOptionalComma().failed())
+ break;
}
+
+ RETURN_ON_FAIL(parser.parseRBrace())
+ RETURN_ON_FAIL(parser.parseGreater())
+#undef RETURN_ON_FAIL
+
// Construct struct-like storage for attribute.
return parser.getChecked<SparseTensorEncodingAttr>(
- parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind);
+ parser.getContext(), dlt, dimOrd, higherOrd, ptr, ind, slices);
}
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
@@ -188,14 +323,25 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
printer << ", pointerBitWidth = " << getPointerBitWidth();
if (getIndexBitWidth())
printer << ", indexBitWidth = " << getIndexBitWidth();
+ if (!getDimSlices().empty()) {
+ printer << ", slice = [ ";
+ llvm::interleaveComma(getDimSlices(), printer,
+ [&](SparseTensorDimSliceAttr attr) {
+ // Calls SparseTensorDimSliceAttr::print directly to
+ // skip mnemonic.
+ attr.print(printer);
+ });
+ printer << " ]";
+ }
+
printer << " }>";
}
LogicalResult SparseTensorEncodingAttr::verify(
function_ref<InFlightDiagnostic()> emitError,
ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
- AffineMap higherOrdering, unsigned pointerBitWidth,
- unsigned indexBitWidth) {
+ AffineMap higherOrdering, unsigned pointerBitWidth, unsigned indexBitWidth,
+ ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
if (!acceptBitWidth(pointerBitWidth))
return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
if (!acceptBitWidth(indexBitWidth))
@@ -217,6 +363,11 @@ LogicalResult SparseTensorEncodingAttr::verify(
return emitError() << "unexpected mismatch in higher ordering and "
"dimension level types size";
}
+ if (!dimSlices.empty() && dimSlices.size() != dimLevelType.size()) {
+ return emitError() << "unexpected mismatch in dimension slices and "
+ "dimension level type size";
+ }
+
return success();
}
@@ -226,7 +377,7 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
// Check structural integrity.
if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
getHigherOrdering(), getPointerBitWidth(),
- getIndexBitWidth())))
+ getIndexBitWidth(), getDimSlices())))
return failure();
// Check integrity with tensor type specifics. Dimension ordering is optional,
// but we always should have dimension level types for the full rank.
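As a usage illustration (not part of the patch), a client pass could query the new slice metadata through the encoding of a ranked tensor type roughly as sketched below; it assumes the existing getSparseTensorEncoding helper from SparseTensor.h together with the accessors added above:

```c++
#include <cstdint>
#include <optional>
#include <string>

#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

static void dumpSliceInfo(RankedTensorType tensorTp) {
  SparseTensorEncodingAttr enc = getSparseTensorEncoding(tensorTp);
  if (!enc || !enc.isSlice())
    return; // Not a sparse tensor slice; nothing to report.
  for (int64_t d = 0, rank = tensorTp.getRank(); d < rank; d++) {
    // Each component is std::nullopt when spelled `?` in the IR.
    std::optional<uint64_t> off = enc.getStaticDimSliceOffset(d);
    std::optional<uint64_t> sz = enc.getStaticDimSliceSize(d);
    std::optional<uint64_t> st = enc.getStaticDimSliceStride(d);
    llvm::errs() << "dim " << d
                 << ": offset=" << (off ? std::to_string(*off) : "?")
                 << " size=" << (sz ? std::to_string(*sz) : "?")
                 << " stride=" << (st ? std::to_string(*st) : "?") << "\n";
  }
}
```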
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 4830a665bd0ac..cfbf14a799f25 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -68,3 +68,10 @@ func.func private @tensor_invalid_key(%arg0: tensor<16x32xf32, #a>) -> ()
#a = #sparse_tensor.encoding<{dimLevelType = [ "compressed", "compressed", "dense", "dense" ], dimOrdering = affine_map<(ii, jj, i, j) -> (ii, jj, i, j)>, higherOrdering = affine_map<(i, j) -> (j, i)>}> // expected-error {{unexpected higher ordering mapping from 2 to 2}}
func.func private @tensor_invalid_key(%arg0: tensor<10x60xf32, #a>) -> ()
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (-1, ?, 1), (?, 4, 2) ] // expected-error{{expect positive value or ? for slice offset/size/stride}}
+}>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index d8eda2ac6b3db..b7bc7c619aa60 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -88,3 +88,35 @@ func.func private @sparse_bcsr(tensor<10x60xf64, #BCSR>)
// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ], higherOrdering = affine_map<(d0, d1)[s0] -> (d0 * (s0 * 4), d0, d1)> }>>
func.func private @sparse_ell(tensor<?x?xf64, #ELL>)
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+// CHECK-LABEL: func private @sparse_slice(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], slice = [ (1, 4, 1), (1, 4, 2) ] }>>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
+
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (1, ?, 1), (?, 4, 2) ]
+}>
+
+// CHECK-LABEL: func private @sparse_slice(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], slice = [ (1, ?, 1), (?, 4, 2) ] }>>
+func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
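The round-trip behavior that the tests above check can also be reproduced programmatically. A minimal sketch (not part of the patch), assuming mlir::parseAttribute from mlir/AsmParser/AsmParser.h and a context with the sparse tensor dialect loaded:

```c++
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  ctx.getOrLoadDialect<mlir::sparse_tensor::SparseTensorDialect>();
  constexpr const char *src =
      "#sparse_tensor.encoding<{"
      " dimLevelType = [ \"dense\", \"compressed\" ],"
      " slice = [ (1, ?, 1), (?, 4, 2) ]"
      " }>";
  auto enc = mlir::parseAttribute(src, &ctx)
                 .dyn_cast_or_null<mlir::sparse_tensor::SparseTensorEncodingAttr>();
  if (!enc || !enc.isSlice())
    return 1;
  // Printing the attribute again should match the CHECK lines above.
  enc.dump();
  return 0;
}
```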