[Mlir-commits] [mlir] [mlir][sparse] support querying sparse buffer types from sparse tenso… (PR #88308)
Peiming Liu
llvmlistbot at llvm.org
Wed Apr 10 11:43:44 PDT 2024
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/88308
…r encodings.
>From 40457f317187863e79c0ecda2e0e2313c240c99c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 10 Apr 2024 18:26:48 +0000
Subject: [PATCH] [mlir][sparse] support querying sparse buffer types from
sparse tensor encodings.
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 20 ++++++
.../SparseTensor/IR/SparseTensorType.h | 12 +---
.../SparseTensor/IR/SparseTensorDialect.cpp | 62 ++++++++++++++++---
3 files changed, 75 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index d3be8a3009ba1e..9dc7cae714fc95 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -401,6 +401,26 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// the null encoding (since dense-tensors are always all-ordered).
bool isAllOrdered() const;
+  //
+  // Storage type methods.
+  //
+
+  /// Coordinate type: i<crdWidth>, or `IndexType` if 0; null for null encoding.
+  Type getCrdElemType() const;
+
+  /// Position type: i<posWidth>, or `IndexType` if 0; null for null encoding.
+  Type getPosElemType() const;
+
+  /// Returns the `memref<[batch] x ? x crd>` type; a given dimShape refines
+  /// the otherwise dynamic batch sizes. Requires a non-null encoding.
+  MemRefType getCrdMemRefType(
+      std::optional<ArrayRef<int64_t>> dimShape = std::nullopt) const;
+
+  /// Returns the `memref<[batch] x ? x pos>` type; a given dimShape refines
+  /// the otherwise dynamic batch sizes. Requires a non-null encoding.
+  MemRefType getPosMemRefType(
+      std::optional<ArrayRef<int64_t>> dimShape = std::nullopt) const;
+
//
// dimToLvl methods.
//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index dc770c2e904cbd..825d89a408febe 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -328,18 +328,16 @@ class SparseTensorType {
unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }
/// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
- Type getCrdType() const {
- if (getCrdWidth())
- return IntegerType::get(getContext(), getCrdWidth());
- return IndexType::get(getContext());
- }
+  Type getCrdType() const {
+    // Without an encoding, enc.getCrdElemType() would return a null type.
+    return enc ? enc.getCrdElemType() : IndexType::get(getContext());
+  }
/// Returns the position-overhead MLIR type, defaulting to `IndexType`.
- Type getPosType() const {
- if (getPosWidth())
- return IntegerType::get(getContext(), getPosWidth());
- return IndexType::get(getContext());
- }
+  Type getPosType() const {
+    // Without an encoding, enc.getPosElemType() would return a null type.
+    return enc ? enc.getPosElemType() : IndexType::get(getContext());
+  }
/// Returns true iff this sparse tensor type has a trailing
/// COO region starting at the given level. By default, it
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e4d93c5623b9c4..e9058394d33da5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -61,6 +61,26 @@ static constexpr bool acceptBitWidth(unsigned bitWidth) {
}
}
+static SmallVector<Size>
+getSparseFieldShape(const SparseTensorEncodingAttr enc,
+                    std::optional<ArrayRef<int64_t>> dimShape) {
+  assert(enc);
+  const auto batchRank = enc.getBatchLvlRank();
+  // With only the encoding, the static sizes of the leading batch levels are
+  // unknown, so they default to dynamic.
+  SmallVector<Size> memrefShape(batchRank, ShapedType::kDynamic);
+  if (dimShape.has_value()) {
+    // If the actual tensor shape is provided, refine the batch level sizes.
+    const SmallVector<int64_t> lvlShape =
+        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
+    assert(lvlShape.size() >= batchRank && "lvl shape lacks batch levels");
+    memrefShape.assign(lvlShape.begin(), lvlShape.begin() + batchRank);
+  }
+  // One more dynamic dimension stores the sparse level itself.
+  memrefShape.push_back(ShapedType::kDynamic);
+  return memrefShape;
+}
+
//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//
@@ -122,21 +142,17 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
LevelType)>
callback) {
assert(stt.hasEncoding());
- // Construct the basic types.
- const Type crdType = stt.getCrdType();
- const Type posType = stt.getPosType();
- const Type eltType = stt.getElementType();
- SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
- memrefShape.push_back(ShapedType::kDynamic);
+ SmallVector<int64_t> memrefShape =
+ getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
const Type specType = StorageSpecifierType::get(stt.getEncoding());
// memref<[batch] x ? x pos> positions
- const Type posMemType = MemRefType::get(memrefShape, posType);
+ const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
// memref<[batch] x ? x crd> coordinates
- const Type crdMemType = MemRefType::get(memrefShape, crdType);
+ const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
// memref<[batch] x ? x eltType> values
- const Type valMemType = MemRefType::get(memrefShape, eltType);
+ const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
callback](FieldIndex fieldIdx,
@@ -354,6 +370,34 @@ bool SparseTensorEncodingAttr::isAllOrdered() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}
+Type SparseTensorEncodingAttr::getCrdElemType() const {
+  // The null encoding has no coordinate type; width 0 selects `index`.
+  if (!getImpl())
+    return nullptr;
+  return getCrdWidth() ? Type(IntegerType::get(getContext(), getCrdWidth()))
+                       : Type(IndexType::get(getContext()));
+}
+
+Type SparseTensorEncodingAttr::getPosElemType() const {
+  // The null encoding has no position type; width 0 selects `index`.
+  if (!getImpl())
+    return nullptr;
+  return getPosWidth() ? Type(IntegerType::get(getContext(), getPosWidth()))
+                       : Type(IndexType::get(getContext()));
+}
+
+MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
+    std::optional<ArrayRef<int64_t>> dimShape) const {
+  // Shape is [batch sizes..., ?]; elements use the coordinate overhead type.
+  const Type crdType = getCrdElemType();
+  return MemRefType::get(getSparseFieldShape(*this, dimShape), crdType);
+}
+
+MemRefType SparseTensorEncodingAttr::getPosMemRefType(
+    std::optional<ArrayRef<int64_t>> dimShape) const {
+  // Shape is [batch sizes..., ?]; elements use the position overhead type.
+  const Type posType = getPosElemType();
+  return MemRefType::get(getSparseFieldShape(*this, dimShape), posType);
+}
+
bool SparseTensorEncodingAttr::isIdentity() const {
return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}
More information about the Mlir-commits
mailing list