[Mlir-commits] [mlir] [mlir][sparse] support type conversion from batched sparse tensors to… (PR #83163)

Peiming Liu llvmlistbot at llvm.org
Tue Feb 27 10:14:51 PST 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/83163

>From f593e936fb5a10e520af6403d97eeda70d42f3e5 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 27 Feb 2024 18:03:04 +0000
Subject: [PATCH] [mlir][sparse] support type conversion from batched sparse
 tensors to memrefs.

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |  2 +-
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  2 ++
 .../IR/SparseTensorStorageLayout.h            | 14 ++++++-------
 .../SparseTensor/IR/SparseTensorType.h        |  8 +++++++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 21 +++++++++++++------
 .../Transforms/SparseTensorCodegen.cpp        |  5 ++---
 mlir/test/Dialect/SparseTensor/codegen.mlir   | 15 +++++++++++++
 7 files changed, 50 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index cc134e7d953ec6..9e79b6aca1c9ba 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -342,7 +342,7 @@ struct LevelType {
   /// Check if the `LevelType` needs coordinates array.
   constexpr bool isWithCrdLT() const {
     // All sparse levels has coordinate array.
-    return !isa<LevelFormat::Dense>();
+    return !isa<LevelFormat::Dense, LevelFormat::Batch>();
   }
 
   std::string toMLIRString() const {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index ca98665256be5a..5d1db2323f95f0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -374,6 +374,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// is non-null (since no fixed result is valid for every dense-tensor).
     ::mlir::sparse_tensor::Level getLvlRank() const;
 
+    uint64_t getBatchLvlRank() const;
+
     //
     // lvlTypes methods.
     //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
index 27dc39609cdadd..ce34ae43d1c181 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -30,15 +30,15 @@ namespace sparse_tensor {
 ///   ;  if dense:
 ///        <nothing>
 ///   ;  if compressed:
-///        memref<? x pos>  positions   ; positions for level l
-///        memref<? x crd>  coordinates ; coordinates for level l
-///   ;  if loose-compressed:
-///        memref<? x pos>  positions   ; lo/hi position pairs for level l
-///        memref<? x crd>  coordinates ; coordinates for level l
+///        memref<[batch] x ? x pos>  positions   ; positions for level l
+///        memref<[batch] x ? x crd>  coordinates ; coordinates for level l
+///   ;  if loose-compressed:
+///        memref<[batch] x ? x pos>  positions   ; lo/hi pos pairs for level l
+///        memref<[batch] x ? x crd>  coordinates ; coordinates for level l
 ///   ;  if singleton/2-out-of-4:
-///        memref<? x crd>  coordinates ; coordinates for level l
+///        memref<[batch] x ? x crd>  coordinates ; coordinates for level l
 ///
-///   memref<? x eltType> values        ; values
+///   memref<[batch] x ? x eltType> values        ; values
 ///
 ///   struct sparse_tensor.storage_specifier {
 ///     array<rank x int> lvlSizes    ; sizes/cardinalities for each level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 1a090ddb782fdb..c93a4fcd922c28 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -253,6 +253,14 @@ class SparseTensorType {
                                        CrdTransDirectionKind::dim2lvl);
   }
 
+  /// Returns the batch level-shape (level-shape truncated to the batch level rank).
+  SmallVector<Size> getBatchLvlShape() const {
+    auto lvlShape = getEncoding().tranlateShape(getDimShape(),
+                                                CrdTransDirectionKind::dim2lvl);
+    lvlShape.truncate(getEncoding().getBatchLvlRank());
+    return lvlShape;
+  }
+
   /// Returns the type with an identity mapping.
   RankedTensorType getDemappedType() const {
     return RankedTensorType::get(getLvlShape(), getElementType(),
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fd0ed26fbde072..69c3413f35ea9c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -126,13 +126,16 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
   const Type posType = stt.getPosType();
   const Type eltType = stt.getElementType();
 
+  SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
+  memrefShape.push_back(ShapedType::kDynamic);
+
   const Type specType = StorageSpecifierType::get(stt.getEncoding());
-  // memref<? x pos>  positions
-  const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
-  // memref<? x crd>  coordinates
-  const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
-  // memref<? x eltType> values
-  const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+  // memref<[batch] x ? x pos>  positions
+  const Type posMemType = MemRefType::get(memrefShape, posType);
+  // memref<[batch] x ? x crd>  coordinates
+  const Type crdMemType = MemRefType::get(memrefShape, crdType);
+  // memref<[batch] x ? x eltType> values
+  const Type valMemType = MemRefType::get(memrefShape, eltType);
 
   StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                    callback](FieldIndex fieldIdx,
@@ -336,6 +339,12 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
   return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
 }
 
+uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
+  ArrayRef<LevelType> lvlTypes = getLvlTypes();
+  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
+  return std::distance(lastBatch, lvlTypes.rend());
+}
+
 bool SparseTensorEncodingAttr::isAllDense() const {
   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0ccb11f3a6b858..d5eec4ae67e798 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1293,7 +1293,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
             Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                                ? op.getValues()
                                : op.getLevels()[fIdx];
-
+            // TODO: handle batch.
             TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
             if (mem.getType().getRank() > 1) {
               // Flattens the buffer to rank 1.
@@ -1322,9 +1322,8 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
     for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
       assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
 
-      // FIXME: dim/lvl confusion!
       // Sets up the level size.
-      auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
+      auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
       desc.setLvlSize(rewriter, loc, lvl, lvlSize);
       // We use a single AOS array to store the trailing COO, so there is only
       // one memory size to set for the entire COO section.
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index c1a976c84fecca..64a515a38588a2 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -34,6 +34,10 @@
   map = (d0, d1) -> (d1 : dense, d0 : compressed)
 }>
 
+#BCSR = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (d0: batch, d1: batch, d2 : dense, d3 : compressed)
+}>
+
 #DCSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed, d1 : compressed),
   crdWidth = 64,
@@ -182,6 +186,17 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
   return
 }
 
+// CHECK-LABEL: func @sparse_bcsr(
+//  CHECK-SAME: %[[A0:.*0]]: memref<?x2x?xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?x2x?xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?x2x?xindex>,
+//  CHECK-SAME: %[[A3:.*]]: memref<?x2x?xf64>,
+//  CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+//       CHECK: return
+func.func @sparse_bcsr(%arg0: tensor<?x2x?x?xf64, #BCSR>) {
+  return
+}
+
 // CHECK-LABEL: func @sparse_dcsr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,



More information about the Mlir-commits mailing list