[Mlir-commits] [mlir] [mlir][sparse] support querying sparse buffer types from sparse tensor encodings (PR #88308)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 10 11:48:18 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>

Support querying sparse buffer types from sparse tensor encodings.

---
Full diff: https://github.com/llvm/llvm-project/pull/88308.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+20) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+2-10) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+53-9) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index d3be8a3009ba1e..9dc7cae714fc95 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -401,6 +401,26 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// the null encoding (since dense-tensors are always all-ordered).
     bool isAllOrdered() const;
 
+    //
+    // storage type methods; the null encoding yields null element types.
+    //
+
+    /// Returns the crd-overhead type: iN for crdWidth N, else `IndexType`.
+    Type getCrdElemType() const;
+
+    /// Returns the pos-overhead type: iN for posWidth N, else `IndexType`.
+    Type getPosElemType() const;
+
+    /// Returns the coordinate-memref MLIR type; the optional `tensorDimShape`
+    /// refines the leading batch dimensions (otherwise they are dynamic).
+    MemRefType getCrdMemRefType(
+      std::optional<ArrayRef<int64_t>> tensorDimShape = std::nullopt) const;
+
+    /// Returns the position-memref MLIR type; the optional `tensorDimShape`
+    /// refines the leading batch dimensions (otherwise they are dynamic).
+    MemRefType getPosMemRefType(
+      std::optional<ArrayRef<int64_t>> tensorDimShape = std::nullopt) const;
+
     //
     // dimToLvl methods.
     //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index dc770c2e904cbd..825d89a408febe 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -328,18 +328,16 @@ class SparseTensorType {
   unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }
 
   /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
-  Type getCrdType() const {
-    if (getCrdWidth())
-      return IntegerType::get(getContext(), getCrdWidth());
-    return IndexType::get(getContext());
-  }
+  Type getCrdType() const {
+    // Without an encoding (dense tensor), fall back to `IndexType`, not null.
+    return enc ? enc.getCrdElemType() : IndexType::get(getContext());
+  }
 
   /// Returns the position-overhead MLIR type, defaulting to `IndexType`.
-  Type getPosType() const {
-    if (getPosWidth())
-      return IntegerType::get(getContext(), getPosWidth());
-    return IndexType::get(getContext());
-  }
+  Type getPosType() const {
+    // Without an encoding (dense tensor), fall back to `IndexType`, not null.
+    return enc ? enc.getPosElemType() : IndexType::get(getContext());
+  }
 
   /// Returns true iff this sparse tensor type has a trailing
   /// COO region starting at the given level. By default, it
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e4d93c5623b9c4..e9058394d33da5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -61,6 +61,26 @@ static constexpr bool acceptBitWidth(unsigned bitWidth) {
   }
 }
 
+static SmallVector<Size>
+getSparseFieldShape(const SparseTensorEncodingAttr enc,
+                    std::optional<ArrayRef<int64_t>> dimShape) {
+  assert(enc);
+  // From the encoding alone we cannot determine the static sizes of the
+  // leading batch levels, so those dimensions default to dynamic.
+  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
+  if (dimShape.has_value()) {
+    // When the tensor's dimension shape is given, translate it to level space
+    // (dim2lvl) and take the leading batch-level sizes from it.
+    SmallVector<int64_t> lvlShape =
+        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
+    memrefShape.assign(lvlShape.begin(),
+                       lvlShape.begin() + enc.getBatchLvlRank());
+  }
+  // Append the trailing dynamic dimension that stores the sparse level data.
+  memrefShape.push_back(ShapedType::kDynamic);
+  return memrefShape;
+}
+
 //===----------------------------------------------------------------------===//
 // SparseTensorDialect StorageLayout.
 //===----------------------------------------------------------------------===//
@@ -122,21 +142,17 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
                             LevelType)>
         callback) {
   assert(stt.hasEncoding());
-  // Construct the basic types.
-  const Type crdType = stt.getCrdType();
-  const Type posType = stt.getPosType();
-  const Type eltType = stt.getElementType();
 
-  SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
-  memrefShape.push_back(ShapedType::kDynamic);
+  SmallVector<int64_t> memrefShape =
+      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
 
   const Type specType = StorageSpecifierType::get(stt.getEncoding());
   // memref<[batch] x ? x pos>  positions
-  const Type posMemType = MemRefType::get(memrefShape, posType);
+  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
   // memref<[batch] x ? x crd>  coordinates
-  const Type crdMemType = MemRefType::get(memrefShape, crdType);
+  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
   // memref<[batch] x ? x eltType> values
-  const Type valMemType = MemRefType::get(memrefShape, eltType);
+  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
 
   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                    callback](FieldIndex fieldIdx,
@@ -354,6 +370,34 @@ bool SparseTensorEncodingAttr::isAllOrdered() const {
   return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
 }
 
+Type SparseTensorEncodingAttr::getCrdElemType() const {
+  if (!getImpl())
+    return nullptr;
+  if (getCrdWidth())
+    return IntegerType::get(getContext(), getCrdWidth());
+  return IndexType::get(getContext());
+}
+
+Type SparseTensorEncodingAttr::getPosElemType() const {
+  if (!getImpl())
+    return nullptr;
+  if (getPosWidth())
+    return IntegerType::get(getContext(), getPosWidth());
+  return IndexType::get(getContext());
+}
+
+MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
+    std::optional<ArrayRef<int64_t>> dimShape) const {
+  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
+  return MemRefType::get(shape, getCrdElemType());
+}
+
+MemRefType SparseTensorEncodingAttr::getPosMemRefType(
+    std::optional<ArrayRef<int64_t>> dimShape) const {
+  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
+  return MemRefType::get(shape, getPosElemType());
+}
+
 bool SparseTensorEncodingAttr::isIdentity() const {
   return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/88308


More information about the Mlir-commits mailing list