[Mlir-commits] [mlir] [mlir][sparse] support type conversion from batched sparse tensors to… (PR #83163)
Peiming Liu
llvmlistbot at llvm.org
Tue Feb 27 10:15:35 PST 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/83163
>From 5fae104e5fc3b2cfc873944f0efe043190490722 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 27 Feb 2024 18:03:04 +0000
Subject: [PATCH] [mlir][sparse] support type conversion from batched sparse
tensors to memrefs.
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 2 +-
.../SparseTensor/IR/SparseTensorAttrDefs.td | 2 ++
.../IR/SparseTensorStorageLayout.h | 14 ++++++-------
.../SparseTensor/IR/SparseTensorType.h | 8 +++++++
.../SparseTensor/IR/SparseTensorDialect.cpp | 21 +++++++++++++------
.../Transforms/SparseTensorCodegen.cpp | 5 ++---
mlir/test/Dialect/SparseTensor/codegen.mlir | 14 +++++++++++++
7 files changed, 49 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index cc134e7d953ec6..9e79b6aca1c9ba 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -342,7 +342,7 @@ struct LevelType {
/// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT() const {
// All sparse levels has coordinate array.
- return !isa<LevelFormat::Dense>();
+ return !isa<LevelFormat::Dense, LevelFormat::Batch>();
}
std::string toMLIRString() const {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index ca98665256be5a..5d1db2323f95f0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -374,6 +374,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// is non-null (since no fixed result is valid for every dense-tensor).
::mlir::sparse_tensor::Level getLvlRank() const;
+ uint64_t getBatchLvlRank() const;
+
//
// lvlTypes methods.
//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
index 27dc39609cdadd..ce34ae43d1c181 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -30,15 +30,15 @@ namespace sparse_tensor {
/// ; if dense:
/// <nothing>
/// ; if compressed:
-/// memref<? x pos> positions ; positions for level l
-/// memref<? x crd> coordinates ; coordinates for level l
-/// ; if loose-compressed:
-/// memref<? x pos> positions ; lo/hi position pairs for level l
-/// memref<? x crd> coordinates ; coordinates for level l
+/// memref<[batch] x ? x pos> positions ; positions for level l
+/// memref<[batch] x ? x crd> coordinates ; coordinates for level l
+/// ; if loose-compressed:
+/// memref<[batch] x ? x pos> positions ; lo/hi pos pairs for level l
+/// memref<[batch] x ? x crd> coordinates ; coordinates for level l
/// ; if singleton/2-out-of-4:
-/// memref<? x crd> coordinates ; coordinates for level l
+/// memref<[batch] x ? x crd> coordinates ; coordinates for level l
///
-/// memref<? x eltType> values ; values
+/// memref<[batch] x ? x eltType> values ; values
///
/// struct sparse_tensor.storage_specifier {
/// array<rank x int> lvlSizes ; sizes/cardinalities for each level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 1a090ddb782fdb..c93a4fcd922c28 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -253,6 +253,14 @@ class SparseTensorType {
CrdTransDirectionKind::dim2lvl);
}
+ /// Returns the batch level-shape (the level-shape truncated to the batch level-rank).
+ SmallVector<Size> getBatchLvlShape() const {
+ auto lvlShape = getEncoding().tranlateShape(getDimShape(),
+ CrdTransDirectionKind::dim2lvl);
+ lvlShape.truncate(getEncoding().getBatchLvlRank());
+ return lvlShape;
+ }
+
/// Returns the type with an identity mapping.
RankedTensorType getDemappedType() const {
return RankedTensorType::get(getLvlShape(), getElementType(),
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fd0ed26fbde072..69c3413f35ea9c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -126,13 +126,16 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
const Type posType = stt.getPosType();
const Type eltType = stt.getElementType();
+ SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
+ memrefShape.push_back(ShapedType::kDynamic);
+
const Type specType = StorageSpecifierType::get(stt.getEncoding());
- // memref<? x pos> positions
- const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
- // memref<? x crd> coordinates
- const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
- // memref<? x eltType> values
- const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+ // memref<[batch] x ? x pos> positions
+ const Type posMemType = MemRefType::get(memrefShape, posType);
+ // memref<[batch] x ? x crd> coordinates
+ const Type crdMemType = MemRefType::get(memrefShape, crdType);
+ // memref<[batch] x ? x eltType> values
+ const Type valMemType = MemRefType::get(memrefShape, eltType);
StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
callback](FieldIndex fieldIdx,
@@ -336,6 +339,12 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}
+uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
+ ArrayRef<LevelType> lvlTypes = getLvlTypes();
+ auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
+ return std::distance(lastBatch, lvlTypes.rend());
+}
+
bool SparseTensorEncodingAttr::isAllDense() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0ccb11f3a6b858..d5eec4ae67e798 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1293,7 +1293,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
Value tensor = fKind == SparseTensorFieldKind::ValMemRef
? op.getValues()
: op.getLevels()[fIdx];
-
+ // TODO: handle batch.
TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
if (mem.getType().getRank() > 1) {
// Flattens the buffer to rank 1.
@@ -1322,9 +1322,8 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
- // FIXME: dim/lvl confusion!
// Sets up the level size.
- auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
+ auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
desc.setLvlSize(rewriter, loc, lvl, lvlSize);
// We use a single AOS array to store the trailing COO, so there is only
// one memory size to set for the entire COO section.
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index c1a976c84fecca..c32b2e0c2cbf58 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -34,6 +34,10 @@
map = (d0, d1) -> (d1 : dense, d0 : compressed)
}>
+#BCSR = #sparse_tensor.encoding<{
+ map = (d0, d1, d2, d3) -> (d0: batch, d1: batch, d2 : dense, d3 : compressed)
+}>
+
#DCSR = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed, d1 : compressed),
crdWidth = 64,
@@ -182,6 +186,16 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
return
}
+// CHECK-LABEL: func @sparse_bcsr(
+// CHECK-SAME: %[[A1:.*1]]: memref<?x2x?xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?x2x?xindex>,
+// CHECK-SAME: %[[A3:.*]]: memref<?x2x?xf64>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK: return
+func.func @sparse_bcsr(%arg0: tensor<?x2x?x?xf64, #BCSR>) {
+ return
+}
+
// CHECK-LABEL: func @sparse_dcsr(
// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
More information about the Mlir-commits
mailing list