[Mlir-commits] [mlir] [mlir][sparse] support type conversion from batched sparse tensors to… (PR #83163)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 27 10:06:02 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>

Support type conversion from batched sparse tensors to memrefs.

---
Full diff: https://github.com/llvm/llvm-project/pull/83163.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h (+7-7) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+8) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+15-6) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+2-3) 
- (modified) mlir/test/Dialect/SparseTensor/codegen.mlir (+15) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index ca98665256be5a..5d1db2323f95f0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -374,6 +374,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// is non-null (since no fixed result is valid for every dense-tensor).
     ::mlir::sparse_tensor::Level getLvlRank() const;
 
+    uint64_t getBatchLvlRank() const;
+
     //
     // lvlTypes methods.
     //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
index 27dc39609cdadd..ce34ae43d1c181 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -30,15 +30,15 @@ namespace sparse_tensor {
 ///   ;  if dense:
 ///        <nothing>
 ///   ;  if compressed:
-///        memref<? x pos>  positions   ; positions for level l
-///        memref<? x crd>  coordinates ; coordinates for level l
-///   ;  if loose-compressed:
-///        memref<? x pos>  positions   ; lo/hi position pairs for level l
-///        memref<? x crd>  coordinates ; coordinates for level l
+///        memref<[batch] x ? x pos>  positions   ; positions for level l
+///        memref<[batch] x ? x crd>  coordinates ; coordinates for level l
+///   ;  if loose-compressed:
+///        memref<[batch] x ? x pos>  positions   ; lo/hi pos pairs for level l
+///        memref<[batch] x ? x crd>  coordinates ; coordinates for level l
 ///   ;  if singleton/2-out-of-4:
-///        memref<? x crd>  coordinates ; coordinates for level l
+///        memref<[batch] x ? x crd>  coordinates ; coordinates for level l
 ///
-///   memref<? x eltType> values        ; values
+///   memref<[batch] x ? x eltType> values        ; values
 ///
 ///   struct sparse_tensor.storage_specifier {
 ///     array<rank x int> lvlSizes    ; sizes/cardinalities for each level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 1a090ddb782fdb..c93a4fcd922c28 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -253,6 +253,14 @@ class SparseTensorType {
                                        CrdTransDirectionKind::dim2lvl);
   }
 
+  /// Returns the batched Level-shape (first `getBatchLvlRank()` sizes).
+  SmallVector<Size> getBatchLvlShape() const {
+    auto lvlShape = getEncoding().tranlateShape(getDimShape(),
+                                                CrdTransDirectionKind::dim2lvl);
+    lvlShape.truncate(getEncoding().getBatchLvlRank());
+    return lvlShape;
+  }
+
   /// Returns the type with an identity mapping.
   RankedTensorType getDemappedType() const {
     return RankedTensorType::get(getLvlShape(), getElementType(),
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fd0ed26fbde072..69c3413f35ea9c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -126,13 +126,16 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
   const Type posType = stt.getPosType();
   const Type eltType = stt.getElementType();
 
+  SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
+  memrefShape.push_back(ShapedType::kDynamic);
+
   const Type specType = StorageSpecifierType::get(stt.getEncoding());
-  // memref<? x pos>  positions
-  const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
-  // memref<? x crd>  coordinates
-  const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
-  // memref<? x eltType> values
-  const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+  // memref<[batch] x ? x pos>  positions
+  const Type posMemType = MemRefType::get(memrefShape, posType);
+  // memref<[batch] x ? x crd>  coordinates
+  const Type crdMemType = MemRefType::get(memrefShape, crdType);
+  // memref<[batch] x ? x eltType> values
+  const Type valMemType = MemRefType::get(memrefShape, eltType);
 
   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                    callback](FieldIndex fieldIdx,
@@ -336,6 +339,12 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
   return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
 }
 
+uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
+  ArrayRef<LevelType> lvlTypes = getLvlTypes();
+  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
+  return std::distance(lastBatch, lvlTypes.rend());
+}
+
 bool SparseTensorEncodingAttr::isAllDense() const {
   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0ccb11f3a6b858..d5eec4ae67e798 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1293,7 +1293,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
             Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                                ? op.getValues()
                                : op.getLevels()[fIdx];
-
+            // TODO: handle batch levels; the flattening below drops batch dims.
             TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
             if (mem.getType().getRank() > 1) {
               // Flattens the buffer to rank 1.
@@ -1322,9 +1322,8 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
       assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
 
-      // FIXME: dim/lvl confusion!
       // Sets up the level size.
-      auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
+      auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
       desc.setLvlSize(rewriter, loc, lvl, lvlSize);
       // We use a single AOS array to store the trailing COO, so there is only
       // one memory size to set for the entire COO section.
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index c1a976c84fecca..64a515a38588a2 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -34,6 +34,10 @@
   map = (d0, d1) -> (d1 : dense, d0 : compressed)
 }>
 
+#BCSR = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (d0: batch, d1: batch, d2 : dense, d3 : compressed)
+}>
+
 #DCSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed, d1 : compressed),
   crdWidth = 64,
@@ -182,6 +186,17 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
   return
 }
 
+// CHECK-LABEL: func @sparse_bcsr(
+//  CHECK-SAME: %[[A0:.*0]]: memref<?x2x?xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?x2x?xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?x2x?xindex>,
+//  CHECK-SAME: %[[A3:.*]]: memref<?x2x?xf64>,
+//  CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+//       CHECK: return
+func.func @sparse_bcsr(%arg0: tensor<?x2x?x?xf64, #BCSR>) {
+  return
+}
+
 // CHECK-LABEL: func @sparse_dcsr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/83163


More information about the Mlir-commits mailing list