[Mlir-commits] [mlir] [mlir][sparse] support querying sparse buffer types from sparse tensor encodings (PR #88308)

Peiming Liu llvmlistbot at llvm.org
Thu Apr 11 09:42:27 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/88308

>From 40457f317187863e79c0ecda2e0e2313c240c99c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 10 Apr 2024 18:26:48 +0000
Subject: [PATCH 1/2] [mlir][sparse] support querying sparse buffer types from
 sparse tensor encodings.

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   | 20 ++++++
 .../SparseTensor/IR/SparseTensorType.h        | 12 +---
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 62 ++++++++++++++++---
 3 files changed, 75 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index d3be8a3009ba1e..9dc7cae714fc95 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -401,6 +401,26 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// the null encoding (since dense-tensors are always all-ordered).
     bool isAllOrdered() const;
 
+    //
+    // storage type methods.
+    //
+
+    /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
+    Type getCrdElemType() const;
+
+    /// Returns the position-overhead MLIR type, defaulting to `IndexType`.
+    Type getPosElemType() const;
+
+    /// Returns the coordinate-memref MLIR type. An optional tensorDimShape
+    /// refines the leading batch dimensions, which are otherwise dynamic.
+    MemRefType getCrdMemRefType(
+      std::optional<ArrayRef<int64_t>> tensorDimShape = std::nullopt) const;
+
+    /// Returns the position-memref MLIR type. An optional tensorDimShape
+    /// refines the leading batch dimensions, which are otherwise dynamic.
+    MemRefType getPosMemRefType(
+      std::optional<ArrayRef<int64_t>> tensorDimShape = std::nullopt) const;
+
     //
     // dimToLvl methods.
     //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index dc770c2e904cbd..825d89a408febe 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -328,18 +328,20 @@ class SparseTensorType {
   unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }
 
   /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
   Type getCrdType() const {
-    if (getCrdWidth())
-      return IntegerType::get(getContext(), getCrdWidth());
-    return IndexType::get(getContext());
+    // A dense tensor has a null `enc`; keep the `IndexType` default for it.
+    if (!enc)
+      return IndexType::get(getContext());
+    return enc.getCrdElemType();
   }
 
   /// Returns the position-overhead MLIR type, defaulting to `IndexType`.
   Type getPosType() const {
-    if (getPosWidth())
-      return IntegerType::get(getContext(), getPosWidth());
-    return IndexType::get(getContext());
+    // A dense tensor has a null `enc`; keep the `IndexType` default for it.
+    if (!enc)
+      return IndexType::get(getContext());
+    return enc.getPosElemType();
   }
 
   /// Returns true iff this sparse tensor type has a trailing
   /// COO region starting at the given level. By default, it
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e4d93c5623b9c4..e9058394d33da5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -61,6 +61,26 @@ static constexpr bool acceptBitWidth(unsigned bitWidth) {
   }
 }
 
+static SmallVector<Size>
+getSparseFieldShape(const SparseTensorEncodingAttr enc,
+                    std::optional<ArrayRef<int64_t>> dimShape) {
+  assert(enc);
+  // Returns the buffer shape `[batch lvls] x ?`. The encoding alone cannot
+  // size the leading batch levels, so they default to dynamic sizes.
+  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
+  if (dimShape.has_value()) {
+    // Given the tensor's dimension shape, translate it to level space and
+    // take the sizes of the leading batch levels from the result.
+    SmallVector<int64_t> lvlShape =
+        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
+    memrefShape.assign(lvlShape.begin(),
+                       lvlShape.begin() + enc.getBatchLvlRank());
+  }
+  // Trailing dynamic dimension that stores the (non-batch) sparse level.
+  memrefShape.push_back(ShapedType::kDynamic);
+  return memrefShape;
+}
+
 //===----------------------------------------------------------------------===//
 // SparseTensorDialect StorageLayout.
 //===----------------------------------------------------------------------===//
@@ -122,21 +142,17 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
                             LevelType)>
         callback) {
   assert(stt.hasEncoding());
-  // Construct the basic types.
-  const Type crdType = stt.getCrdType();
-  const Type posType = stt.getPosType();
-  const Type eltType = stt.getElementType();
 
-  SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
-  memrefShape.push_back(ShapedType::kDynamic);
+  SmallVector<int64_t> memrefShape =
+      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
 
   const Type specType = StorageSpecifierType::get(stt.getEncoding());
   // memref<[batch] x ? x pos>  positions
-  const Type posMemType = MemRefType::get(memrefShape, posType);
+  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
   // memref<[batch] x ? x crd>  coordinates
-  const Type crdMemType = MemRefType::get(memrefShape, crdType);
+  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
   // memref<[batch] x ? x eltType> values
-  const Type valMemType = MemRefType::get(memrefShape, eltType);
+  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
 
   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                    callback](FieldIndex fieldIdx,
@@ -354,6 +370,38 @@ bool SparseTensorEncodingAttr::isAllOrdered() const {
   return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
 }
 
+Type SparseTensorEncodingAttr::getCrdElemType() const {
+  // The null encoding (dense tensors) yields a null type.
+  if (!getImpl())
+    return nullptr;
+  // A zero bit width selects the `index` type.
+  if (getCrdWidth())
+    return IntegerType::get(getContext(), getCrdWidth());
+  return IndexType::get(getContext());
+}
+
+Type SparseTensorEncodingAttr::getPosElemType() const {
+  // The null encoding (dense tensors) yields a null type.
+  if (!getImpl())
+    return nullptr;
+  // A zero bit width selects the `index` type.
+  if (getPosWidth())
+    return IntegerType::get(getContext(), getPosWidth());
+  return IndexType::get(getContext());
+}
+
+MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
+    std::optional<ArrayRef<int64_t>> tensorDimShape) const {
+  SmallVector<Size> shape = getSparseFieldShape(*this, tensorDimShape);
+  return MemRefType::get(shape, getCrdElemType());
+}
+
+MemRefType SparseTensorEncodingAttr::getPosMemRefType(
+    std::optional<ArrayRef<int64_t>> tensorDimShape) const {
+  SmallVector<Size> shape = getSparseFieldShape(*this, tensorDimShape);
+  return MemRefType::get(shape, getPosElemType());
+}
+
 bool SparseTensorEncodingAttr::isIdentity() const {
   return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
 }

>From 93b58f3c4b3eab91c72d5009a8cd8d4ba984d561 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 11 Apr 2024 16:42:12 +0000
Subject: [PATCH 2/2] address comments

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td        | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 9dc7cae714fc95..4a9b9169ae4b86 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -402,7 +402,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     bool isAllOrdered() const;
 
     //
-    // storage type methods.
+    // Storage type methods.
     //
 
     /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.



More information about the Mlir-commits mailing list